Conversation

@zixi-qi commented Nov 20, 2025

Description

Implement the changes described in #1112: automatically use fp8_e5m2 when an FP8-quantized KV cache is requested on Trillium (TPU v6e).
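
For context, a minimal sketch of the kind of override this adds. The function name and signature are my assumptions; the actual logic lives in tpu_inference's tpu_platform.py, per the "tpu_platform.py:89" log lines in the server output below.

# Hypothetical sketch only, not the PR's code.
import logging

logger = logging.getLogger(__name__)

def resolve_kv_cache_dtype(kv_cache_dtype: str, chip_name: str) -> str:
    """Resolve the generic 'fp8' KV-cache dtype to a concrete format per TPU generation."""
    # Per #1112, a plain "fp8" request should resolve to fp8_e5m2 on Trillium (v6e).
    if kv_cache_dtype == "fp8" and "v6e" in chip_name:
        logger.info("Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
        return "fp8_e5m2"
    return kv_cache_dtype

# Example: resolve_kv_cache_dtype("fp8", "TPU v6e") -> "fp8_e5m2"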

Tests

  • unit test
(tpu-inference) qizixi@t1v-n-00a74f4e-w-0:~/tpu-inference$ pytest tests/platforms/test_tpu_platform.py
======================================================== test session starts ========================================================
platform linux -- Python 3.12.12, pytest-9.0.1, pluggy-1.6.0
rootdir: /home/qizixi/tpu-inference
configfile: pyproject.toml
plugins: anyio-4.11.0, mock-3.15.1, jaxtyping-0.3.3
collected 2 items                                                                                                                   

tests/platforms/test_tpu_platform.py ..                                                                                       [100%]

========================================================= warnings summary ==========================================================
.venv/lib/python3.12/site-packages/tpu_info/device.py:32
  /home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/tpu_info/device.py:32: DeprecationWarning: In 3.13 classes created inside an enum will not become a member.  Use the `member` decorator to keep the current behavior.
    class Info(typing.NamedTuple):

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================================== 2 passed, 3 warnings in 4.30s ===================================================
  • e2e test is blocked by a kernel issue since it appears that FP8 KV is not supported (same error happens with or without this change); a standalone repro sketch follows the error excerpt below
vllm bench throughput --model meta-llama/Llama-3.1-8B --tensor-parallel-size 1 --dtype bfloat16 --kv-cache-dtype fp8 --max-model-len 4096 --max-num-seqs 128 --num-prompts 100 --dataset-name random --input-len 1024 --output-len 100

(EngineCore_DP0 pid=331491)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2225, in _convert_element_type_lowering_rule
(EngineCore_DP0 pid=331491)     return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
(EngineCore_DP0 pid=331491)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=331491)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1077, in f_lowered
(EngineCore_DP0 pid=331491)     out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
(EngineCore_DP0 pid=331491)           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=331491)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=331491)     raise
(EngineCore_DP0 pid=331491)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2208, in _convert_element_type_lowering_rule
(EngineCore_DP0 pid=331491)     raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
(EngineCore_DP0 pid=331491) NotImplementedError: Unsupported cast: uint16 -> uint32
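
For reference, the failing cast can be isolated from the attention kernel. Below is a minimal repro sketch of my own (not code from this PR) that attempts the same uint16 -> uint32 element-type conversion inside a Pallas TPU kernel, which should hit the same missing Mosaic lowering rule:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def cast_kernel(x_ref, o_ref):
    # Mirrors the widening step in the kernel's _convert_to_target_bitwidth:
    # a plain element-type cast from uint16 to uint32 in the kernel body.
    o_ref[...] = x_ref[...].astype(jnp.uint32)

x = jnp.zeros((8, 128), dtype=jnp.uint16)
# On a TPU backend whose Mosaic lowering lacks this cast rule, lowering raises:
# NotImplementedError: Unsupported cast: uint16 -> uint32
y = pl.pallas_call(
    cast_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, jnp.uint32),
)(x)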

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have added necessary comments in my code, particularly in hard-to-understand areas.
  • I have made or will make corresponding changes to any relevant documentation.

@kyuyeunk (Collaborator) commented:

e2e test is blocked by a kernel issue since it appears that FP8 KV is not supported (same error happens with or without this change)

Also, this shouldn't be the case. Can you verify again?

@zixi-qi (Author) commented Nov 20, 2025

e2e test is blocked by a kernel issue since it appears that FP8 KV is not supported (same error happens with or without this change)

Also, this shouldn't be the case? Can you verify again?

I verified multiple times and still have the issue. Here is my setup:

commit: ada5c211acf838261a4f378382b3a34e6eef9fde
hardware: INFO 11-20 21:10:13 [__init__.py:25] TPU info: node_name=qizixi-tpu-v6e | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1

full server log:

(tpu-inference) qizixi@t1v-n-00a74f4e-w-0:~/tpu-inference$ vllm bench throughput --model meta-llama/Llama-3.1-8B --tensor-parallel-size 1 --dtype bfloat16 --kv-cache-dtype fp8 --max-model-len 4096 --max-num-seqs 128 --num-prompts 100 --dataset-name random --input-len 1024 --output-len 100
INFO 11-20 21:10:13 [__init__.py:25] TPU info: node_name=qizixi-tpu-v6e | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 11-20 21:10:15 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-20 21:10:15 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-20 21:10:15 [interface.py:201] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
INFO 11-20 21:10:15 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-20 21:10:15 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-20 21:10:15 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-20 21:10:17 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=2048.
When dataset path is not set, it will default to random dataset
INFO 11-20 21:10:17 [datasets.py:613] Sampling input_len from [1023, 1023] and output_len from [100, 100]
INFO 11-20 21:10:18 [utils.py:253] non-default args: {'tokenizer': 'meta-llama/Llama-3.1-8B', 'dtype': 'bfloat16', 'kv_cache_dtype': 'fp8', 'seed': 0, 'max_model_len': 4096, 'num_redundant_experts': None, 'eplb_window_size': None, 'eplb_step_interval': None, 'eplb_log_balancedness': None, 'max_num_seqs': 128, 'enable_lora': None, 'reasoning_parser_plugin': '', 'model': 'meta-llama/Llama-3.1-8B'}
INFO 11-20 21:10:19 [model.py:644] Resolved architecture: LlamaForCausalLM
INFO 11-20 21:10:19 [model.py:1769] Using max model len 4096
INFO 11-20 21:10:19 [cache.py:195] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor.
INFO 11-20 21:10:19 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 11-20 21:10:19 [tpu_platform.py:127] Initialized sharding configuration: ShardingConfigManager(total_devices=1, sharding_strategy=ShardingStrategy(tensor_parallelism=1, expert_parallelism=1, sequence_parallelism=1, data_parallelism=1, attention_data_parallelism=1), device_indexes=None)
WARNING 11-20 21:10:19 [tpu_platform.py:164] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
INFO 11-20 21:10:19 [tpu_platform.py:198] Force using UniProcExecutor for JAX on                         single host without pipeline parallelism.
WARNING 11-20 21:10:20 [tpu_platform.py:239] Pin memory is not supported on TPU.
INFO 11-20 21:10:22 [__init__.py:25] TPU info: node_name=qizixi-tpu-v6e | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 11-20 21:10:23 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-20 21:10:23 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-20 21:10:23 [interface.py:201] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
INFO 11-20 21:10:23 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-20 21:10:23 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-20 21:10:23 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:25 [core.py:93] Initializing a V1 LLM engine (v0.1.dev11447+gcb0a7b4be) with config: model='meta-llama/Llama-3.1-8B', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=<class 'jax.numpy.bfloat16'>, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=fp8, device_config=None, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=meta-llama/Llama-3.1-8B, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': <CompilationMode.DYNAMO_TRACE_ONCE: 2>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'openxla', 'custom_ops': ['all'], 'splitting_ops': None, 'compile_mm_encoder': False, 'use_inductor': None, 'compile_sizes': None, 'inductor_compile_config': {'enable_auto_functionalized_v2': False}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': None, 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {}, 'max_cudagraph_capture_size': None, 'local_cache_dir': None}
(EngineCore_DP0 pid=409600) WARNING 11-20 21:10:25 [tpu_platform.py:239] Pin memory is not supported on TPU.
(EngineCore_DP0 pid=409600) WARNING 11-20 21:10:25 [tpu_worker.py:87] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:35 [parallel_state.py:1217] world_size=1 rank=0 local_rank=0 distributed_init_method=file:///tmp/tmpdv5_y65k backend=gloo
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:35 [parallel_state.py:1425] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:35 [tpu_runner.py:303] Init mesh | mesh=Mesh('data': 1, 'model': 1, axis_types=(Auto, Auto))
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:35 [utils.py:93] Prepared token paddings: [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:35 [utils.py:59] Prepared request paddings: [8, 16, 32, 64, 128]
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:35 [compilation_manager.py:34] Enabling JAX compile cache.
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:35 [tpu_worker.py:246] Init worker | rank=0 | node_id=0 | is_driver_worker=True | hbm=[(0.0, 31.25)]GiB
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:35 [model_loader.py:319] Loading model with MODEL_IMPL_TYPE=flax_nnx
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:36 [weight_utils.py:119] Downloading weights from HF meta-llama/Llama-3.1-8B
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:36 [weight_utils.py:141] Loading weights from /home/qizixi/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00001-of-00004.safetensors
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:36 [weight_utils.py:141] Loading weights from /home/qizixi/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00002-of-00004.safetensors
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:36 [weight_utils.py:141] Loading weights from /home/qizixi/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00003-of-00004.safetensors
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:36 [weight_utils.py:141] Loading weights from /home/qizixi/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00004-of-00004.safetensors
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:40 [tpu_runner.py:527] Init model | hbm=[(14.96, 31.25)]GiB
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:40 [tpu_worker.py:275] Memory statistics | total_hbm_limit_gb=31.25GiB | total_hbm_limit_cap_gb=28.12GiB | total_hbm_used_gb=14.96GiB | total_hbm_avail_gb=13.16GiB
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:40 [kv_cache_utils.py:1234] GPU KV cache size: 215,552 tokens
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:40 [kv_cache_utils.py:1239] Maximum concurrency for 4,096 tokens per request: 52.62x
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:42 [kv_cache_manager.py:215] Init kv-cache | num_layers=32 | shape=(num_blocks, (256, 4, 4, 128)) | num_blocks=[842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842] | sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('data', None, 'model'), memory_kind=device) | dtype=float8_e4m3fn | hbm=[(28.12, 31.25)]Gb
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:42 [compilation_manager.py:73] Precompile all the subgraphs with possible input shapes.
(EngineCore_DP0 pid=409600) INFO 11-20 21:10:42 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 16}
(EngineCore_DP0 pid=409600) WARNING 11-20 21:10:42 [tuned_block_sizes.py:4077] Couldn`t find tuned sizes for the RPA v3 kernel with ('TPU v6e', 256, 'q_bfloat16_kv_float8_e4m3fn', 'q_head-32_kv_head-8_head-128', 4096)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] EngineCore failed to start.
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] Traceback (most recent call last):
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "<string>", line 1, in <module>
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/usr/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     exitcode = _main(fd, parent_sentinel)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/usr/lib/python3.12/multiprocessing/spawn.py", line 135, in _main
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return self._bootstrap(parent_sentinel)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.run()
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self._target(*self._args, **self._kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 834, in run_engine_core
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 607, in __init__
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     super().__init__(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 109, in __init__
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 248, in _initialize_kv_caches
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.model_executor.initialize_from_config(kv_cache_configs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/executor/abstract.py", line 116, in initialize_from_config
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.collective_rpc("compile_or_warm_up_model")
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/executor/uniproc_executor.py", line 75, in collective_rpc
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/serial_utils.py", line 479, in run_method
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return func(*args, **kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/worker/tpu_worker.py", line 361, in compile_or_warm_up_model
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.model_runner.capture_model()
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/tpu_runner.py", line 544, in capture_model
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.compilation_manager.capture_model()
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 76, in capture_model
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self._precompile_backbone_text_only()
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 251, in _precompile_backbone_text_only
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self._precompile_backbone_helper("backbone",
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 182, in _precompile_backbone_helper
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self._run_compilation(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 59, in _run_compilation
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     result = fn(*args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 172, in model_fn_wrapper
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     kv_caches, hidden_states, _ = self.runner.model_fn(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/common/model_loader.py", line 224, in run_model
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return model(*args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 323, in __call__
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     kv_caches, x, aux_hidden_states = self.model(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 292, in __call__
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     kv_cache, x = layer(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 213, in __call__
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     kv_cache, attn_output = self.self_attn(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 157, in __call__
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     new_kv_cache, outputs = attention(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 372, in attention
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     output, kv_cache = sharded_ragged_paged_attention(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 331, in sharded_ragged_paged_attention
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return shard_map.shard_map(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 322, in _ragged_paged_attention
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return func(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 1473, in ragged_paged_attention
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/usr/lib/python3.12/contextlib.py", line 81, in inner
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return func(*args, **kwds)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1715, in wrapped
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     jaxpr, consts = _trace_kernel_to_jaxpr(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1210, in _trace_kernel_to_jaxpr
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/primitives.py", line 874, in wrap_with_transforms
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return f(*new_args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 881, in _ragged_paged_attention_kernel
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     @pl.when(seq_idx < decode_end)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/helpers.py", line 70, in _wrapped
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     jax.lax.cond(condition, f, lambda: None)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 883, in process_decode
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     process(static_q_len=1)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 872, in process
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 849, in compute_with_bq
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 825, in compute_with_bkv
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     bkv_lst = strided_load_bkv(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 723, in strided_load_bkv
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return _convert_to_target_bitwidth(kv, target_bitwidth=bitwidth)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 713, in _convert_to_target_bitwidth
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     left_out = _convert_to_target_bitwidth(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 699, in _convert_to_target_bitwidth
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     left = val.astype(next_dtype)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 1089, in meth
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return getattr(self.aval, name).fun(self, *args, **kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 122, in _astype
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return lax_numpy.astype(self, dtype, copy=copy, device=device)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 5370, in astype
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     result = lax_internal._convert_element_type(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Unsupported cast: uint16 -> uint32
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] 
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] 
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] --------------------
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] 
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] The above exception was the direct cause of the following exception:
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] 
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] Traceback (most recent call last):
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 834, in run_engine_core
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 607, in __init__
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     super().__init__(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 109, in __init__
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 248, in _initialize_kv_caches
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.model_executor.initialize_from_config(kv_cache_configs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/executor/abstract.py", line 116, in initialize_from_config
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.collective_rpc("compile_or_warm_up_model")
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/executor/uniproc_executor.py", line 75, in collective_rpc
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/serial_utils.py", line 479, in run_method
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return func(*args, **kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/worker/tpu_worker.py", line 361, in compile_or_warm_up_model
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.model_runner.capture_model()
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/tpu_runner.py", line 544, in capture_model
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self.compilation_manager.capture_model()
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 76, in capture_model
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self._precompile_backbone_text_only()
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 251, in _precompile_backbone_text_only
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self._precompile_backbone_helper("backbone",
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 182, in _precompile_backbone_helper
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     self._run_compilation(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 59, in _run_compilation
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     result = fn(*args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]              ^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 172, in model_fn_wrapper
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     kv_caches, hidden_states, _ = self.runner.model_fn(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]                                   ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1319, in _pallas_call_lowering
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return mlir.lower_per_platform(ctx, "pallas_call",
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1292, in tpu_lowering
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 161, in pallas_call_tpu_lowering_rule
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     mosaic_module, extra_args = lower_module(for_verification=False)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 150, in lower_module
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return lowering.lower_jaxpr_to_module(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 736, in lower_jaxpr_to_module
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     func_op = lower_jaxpr_to_func(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]               ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1050, in lower_jaxpr_to_func
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jaxlib/mlir/dialects/func.py", line 197, in decorator
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return_values = f(*func_args, **func_kwargs)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1046, in body_func
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return jaxpr_subcomp(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     raise
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3188, in _cond_lowering_rule
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     raise
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3058, in _scan_lowering_rule
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     out = _lower_jaxpr_to_for_loop(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]           ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3000, in _lower_jaxpr_to_for_loop
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     args = _run_body(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2983, in _run_body
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     raise
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3110, in _while_lowering_rule
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return _lower_while_via_fori(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3080, in _lower_while_via_fori
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     for_out = _lower_jaxpr_to_for_loop(
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]               ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3015, in _lower_jaxpr_to_for_loop
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     inner_out = _run_body(iv, inner_args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]                 ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2983, in _run_body
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     raise
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2225, in _convert_element_type_lowering_rule
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1077, in f_lowered
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     raise
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2208, in _convert_element_type_lowering_rule
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843]     raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] NotImplementedError: Unsupported cast: uint16 -> uint32
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] --------------------
(EngineCore_DP0 pid=409600) ERROR 11-20 21:10:43 [core.py:843] For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
(EngineCore_DP0 pid=409600) Process EngineCore_DP0:
(EngineCore_DP0 pid=409600) Traceback (most recent call last):
(EngineCore_DP0 pid=409600)   File "<string>", line 1, in <module>
(EngineCore_DP0 pid=409600)   File "/usr/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
(EngineCore_DP0 pid=409600)     exitcode = _main(fd, parent_sentinel)
(EngineCore_DP0 pid=409600)   File "/usr/lib/python3.12/multiprocessing/spawn.py", line 135, in _main
(EngineCore_DP0 pid=409600)     return self._bootstrap(parent_sentinel)
(EngineCore_DP0 pid=409600)   File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_DP0 pid=409600)     self.run()
(EngineCore_DP0 pid=409600)   File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_DP0 pid=409600)     self._target(*self._args, **self._kwargs)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 834, in run_engine_core
(EngineCore_DP0 pid=409600)     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 607, in __init__
(EngineCore_DP0 pid=409600)     super().__init__(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 109, in __init__
(EngineCore_DP0 pid=409600)     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 248, in _initialize_kv_caches
(EngineCore_DP0 pid=409600)     self.model_executor.initialize_from_config(kv_cache_configs)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/executor/abstract.py", line 116, in initialize_from_config
(EngineCore_DP0 pid=409600)     self.collective_rpc("compile_or_warm_up_model")
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/executor/uniproc_executor.py", line 75, in collective_rpc
(EngineCore_DP0 pid=409600)     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/serial_utils.py", line 479, in run_method
(EngineCore_DP0 pid=409600)     return func(*args, **kwargs)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/worker/tpu_worker.py", line 361, in compile_or_warm_up_model
(EngineCore_DP0 pid=409600)     self.model_runner.capture_model()
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/tpu_runner.py", line 544, in capture_model
(EngineCore_DP0 pid=409600)     self.compilation_manager.capture_model()
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 76, in capture_model
(EngineCore_DP0 pid=409600)     self._precompile_backbone_text_only()
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 251, in _precompile_backbone_text_only
(EngineCore_DP0 pid=409600)     self._precompile_backbone_helper("backbone",
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 182, in _precompile_backbone_helper
(EngineCore_DP0 pid=409600)     self._run_compilation(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 59, in _run_compilation
(EngineCore_DP0 pid=409600)     result = fn(*args)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 172, in model_fn_wrapper
(EngineCore_DP0 pid=409600)     kv_caches, hidden_states, _ = self.runner.model_fn(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/models/common/model_loader.py", line 224, in run_model
(EngineCore_DP0 pid=409600)     return model(*args)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 323, in __call__
(EngineCore_DP0 pid=409600)     kv_caches, x, aux_hidden_states = self.model(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 292, in __call__
(EngineCore_DP0 pid=409600)     kv_cache, x = layer(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 213, in __call__
(EngineCore_DP0 pid=409600)     kv_cache, attn_output = self.self_attn(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 157, in __call__
(EngineCore_DP0 pid=409600)     new_kv_cache, outputs = attention(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 372, in attention
(EngineCore_DP0 pid=409600)     output, kv_cache = sharded_ragged_paged_attention(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 331, in sharded_ragged_paged_attention
(EngineCore_DP0 pid=409600)     return shard_map.shard_map(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 322, in _ragged_paged_attention
(EngineCore_DP0 pid=409600)     return func(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 1473, in ragged_paged_attention
(EngineCore_DP0 pid=409600)     output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
(EngineCore_DP0 pid=409600)   File "/usr/lib/python3.12/contextlib.py", line 81, in inner
(EngineCore_DP0 pid=409600)     return func(*args, **kwds)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1715, in wrapped
(EngineCore_DP0 pid=409600)     jaxpr, consts = _trace_kernel_to_jaxpr(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1210, in _trace_kernel_to_jaxpr
(EngineCore_DP0 pid=409600)     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/primitives.py", line 874, in wrap_with_transforms
(EngineCore_DP0 pid=409600)     return f(*new_args)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 881, in _ragged_paged_attention_kernel
(EngineCore_DP0 pid=409600)     @pl.when(seq_idx < decode_end)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/helpers.py", line 70, in _wrapped
(EngineCore_DP0 pid=409600)     jax.lax.cond(condition, f, lambda: None)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 883, in process_decode
(EngineCore_DP0 pid=409600)     process(static_q_len=1)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 872, in process
(EngineCore_DP0 pid=409600)     lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 849, in compute_with_bq
(EngineCore_DP0 pid=409600)     lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 825, in compute_with_bkv
(EngineCore_DP0 pid=409600)     bkv_lst = strided_load_bkv(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 723, in strided_load_bkv
(EngineCore_DP0 pid=409600)     return _convert_to_target_bitwidth(kv, target_bitwidth=bitwidth)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 713, in _convert_to_target_bitwidth
(EngineCore_DP0 pid=409600)     left_out = _convert_to_target_bitwidth(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 699, in _convert_to_target_bitwidth
(EngineCore_DP0 pid=409600)     left = val.astype(next_dtype)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 1089, in meth
(EngineCore_DP0 pid=409600)     return getattr(self.aval, name).fun(self, *args, **kwargs)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 122, in _astype
(EngineCore_DP0 pid=409600)     return lax_numpy.astype(self, dtype, copy=copy, device=device)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 5370, in astype
(EngineCore_DP0 pid=409600)     result = lax_internal._convert_element_type(
(EngineCore_DP0 pid=409600) jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Unsupported cast: uint16 -> uint32
(EngineCore_DP0 pid=409600) 
(EngineCore_DP0 pid=409600) The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
(EngineCore_DP0 pid=409600) 
(EngineCore_DP0 pid=409600) --------------------
(EngineCore_DP0 pid=409600) 
(EngineCore_DP0 pid=409600) The above exception was the direct cause of the following exception:
(EngineCore_DP0 pid=409600) 
(EngineCore_DP0 pid=409600) Traceback (most recent call last):
(EngineCore_DP0 pid=409600)   File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_DP0 pid=409600)     self.run()
(EngineCore_DP0 pid=409600)   File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_DP0 pid=409600)     self._target(*self._args, **self._kwargs)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 847, in run_engine_core
(EngineCore_DP0 pid=409600)     raise e
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 834, in run_engine_core
(EngineCore_DP0 pid=409600)     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=409600)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 607, in __init__
(EngineCore_DP0 pid=409600)     super().__init__(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 109, in __init__
(EngineCore_DP0 pid=409600)     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=409600)                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 248, in _initialize_kv_caches
(EngineCore_DP0 pid=409600)     self.model_executor.initialize_from_config(kv_cache_configs)
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/executor/abstract.py", line 116, in initialize_from_config
(EngineCore_DP0 pid=409600)     self.collective_rpc("compile_or_warm_up_model")
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/executor/uniproc_executor.py", line 75, in collective_rpc
(EngineCore_DP0 pid=409600)     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_DP0 pid=409600)              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/vllm/vllm/v1/serial_utils.py", line 479, in run_method
(EngineCore_DP0 pid=409600)     return func(*args, **kwargs)
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/worker/tpu_worker.py", line 361, in compile_or_warm_up_model
(EngineCore_DP0 pid=409600)     self.model_runner.capture_model()
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/tpu_runner.py", line 544, in capture_model
(EngineCore_DP0 pid=409600)     self.compilation_manager.capture_model()
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 76, in capture_model
(EngineCore_DP0 pid=409600)     self._precompile_backbone_text_only()
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 251, in _precompile_backbone_text_only
(EngineCore_DP0 pid=409600)     self._precompile_backbone_helper("backbone",
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 182, in _precompile_backbone_helper
(EngineCore_DP0 pid=409600)     self._run_compilation(
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 59, in _run_compilation
(EngineCore_DP0 pid=409600)     result = fn(*args)
(EngineCore_DP0 pid=409600)              ^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 172, in model_fn_wrapper
(EngineCore_DP0 pid=409600)     kv_caches, hidden_states, _ = self.runner.model_fn(
(EngineCore_DP0 pid=409600)                                   ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1319, in _pallas_call_lowering
(EngineCore_DP0 pid=409600)     return mlir.lower_per_platform(ctx, "pallas_call",
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1292, in tpu_lowering
(EngineCore_DP0 pid=409600)     return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 161, in pallas_call_tpu_lowering_rule
(EngineCore_DP0 pid=409600)     mosaic_module, extra_args = lower_module(for_verification=False)
(EngineCore_DP0 pid=409600)                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 150, in lower_module
(EngineCore_DP0 pid=409600)     return lowering.lower_jaxpr_to_module(
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 736, in lower_jaxpr_to_module
(EngineCore_DP0 pid=409600)     func_op = lower_jaxpr_to_func(
(EngineCore_DP0 pid=409600)               ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1050, in lower_jaxpr_to_func
(EngineCore_DP0 pid=409600)     body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jaxlib/mlir/dialects/func.py", line 197, in decorator
(EngineCore_DP0 pid=409600)     return_values = f(*func_args, **func_kwargs)
(EngineCore_DP0 pid=409600)                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1046, in body_func
(EngineCore_DP0 pid=409600)     return jaxpr_subcomp(
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600)     raise
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3188, in _cond_lowering_rule
(EngineCore_DP0 pid=409600)     out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
(EngineCore_DP0 pid=409600)           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600)     raise
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3058, in _scan_lowering_rule
(EngineCore_DP0 pid=409600)     out = _lower_jaxpr_to_for_loop(
(EngineCore_DP0 pid=409600)           ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3000, in _lower_jaxpr_to_for_loop
(EngineCore_DP0 pid=409600)     args = _run_body(
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2983, in _run_body
(EngineCore_DP0 pid=409600)     args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600)     raise
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3110, in _while_lowering_rule
(EngineCore_DP0 pid=409600)     return _lower_while_via_fori(
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3080, in _lower_while_via_fori
(EngineCore_DP0 pid=409600)     for_out = _lower_jaxpr_to_for_loop(
(EngineCore_DP0 pid=409600)               ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3015, in _lower_jaxpr_to_for_loop
(EngineCore_DP0 pid=409600)     inner_out = _run_body(iv, inner_args)
(EngineCore_DP0 pid=409600)                 ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2983, in _run_body
(EngineCore_DP0 pid=409600)     args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600)     raise
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2225, in _convert_element_type_lowering_rule
(EngineCore_DP0 pid=409600)     return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
(EngineCore_DP0 pid=409600)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1077, in f_lowered
(EngineCore_DP0 pid=409600)     out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
(EngineCore_DP0 pid=409600)           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=409600)     raise
(EngineCore_DP0 pid=409600)   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2208, in _convert_element_type_lowering_rule
(EngineCore_DP0 pid=409600)     raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
(EngineCore_DP0 pid=409600) NotImplementedError: Unsupported cast: uint16 -> uint32
(EngineCore_DP0 pid=409600) --------------------
(EngineCore_DP0 pid=409600) For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
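
For context, the failing cast appears to come from the bit-width widening that the RPA v3 kernel performs when loading the packed KV cache (tpu_inference/kernels/ragged_paged_attention/v3/kernel.py, strided_load_bkv -> _convert_to_target_bitwidth, which calls val.astype(...)). A minimal sketch of the pattern that trips the lowering, written for illustration only (not code from this repo) and using nothing beyond standard Pallas APIs; it needs a TPU to reach the Mosaic lowering path:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def widen_kernel(x_ref, o_ref):
    # Widening packed words inside a Pallas TPU kernel is the cast that
    # Mosaic rejects here: uint16 -> uint32.
    o_ref[...] = x_ref[...].astype(jnp.uint32)

x = jnp.zeros((8, 128), dtype=jnp.uint16)
out = pl.pallas_call(
    widen_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, jnp.uint32),
)(x)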

@kyuyeunk Would you mind sharing some suggestions on how to debug this issue further? Thanks in advance!

@kyuyeunk
Collaborator

Thanks for investigating this.

I do have a vague guess about where the issue is coming from.

Instead of using --kv-cache-dtype fp8, can you first verify whether the test still fails with --kv-cache-dtype fp8_5m2? This will isolate whether the problem is with supporting fp8_e5m2 at all or with the fp8 to fp8_e5m2 conversion.
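
For context on the two formats (standard jax.numpy dtypes, nothing specific to this repo): e4m3 spends its bits on mantissa precision, while e5m2 keeps a bfloat16-style 5-bit exponent for range, e.g.:

import jax.numpy as jnp

# float8_e4m3fn: more mantissa bits, smaller range
print(jnp.finfo(jnp.float8_e4m3fn).max)  # 448
# float8_e5m2: wider range, fewer mantissa bits
print(jnp.finfo(jnp.float8_e5m2).max)    # 57344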

@zixi-qi
Author

zixi-qi commented Nov 20, 2025

Thanks for investigating this.

I do have a vague guess about where the issue is coming from.

Instead of using --kv-cache-dtype fp8, can you first verify whether the test still fails with --kv-cache-dtype fp8_5m2? This will isolate whether the problem is with supporting fp8_e5m2 at all or with the fp8 to fp8_e5m2 conversion.

Noob question: is --kv-cache-dtype fp8_5m2 a typo?

vllm bench <bench_type> [options] throughput: error: argument --kv-cache-dtype: invalid choice: 'fp8_5m2' (choose from auto, bfloat16, fp8, fp8_ds_mla, fp8_e4m3, fp8_e5m2, fp8_inc)

But anyway, even after reverting this change, I still see the same error:

(tpu-inference) qizixi@t1v-n-00a74f4e-w-0:~/tpu-inference$ vllm bench throughput --model meta-llama/Llama-3.1-8B --tensor-parallel-size 1 --dtype bfloat16 --kv-cache-dtype fp8 --max-model-len 4096 --max-num-seqs 128 --num-prompts 100 --dataset-name random --input-len 1024 --output-len 100
INFO 11-20 23:44:09 [__init__.py:25] TPU info: node_name=qizixi-tpu-v6e | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 11-20 23:44:11 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-20 23:44:11 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-20 23:44:11 [interface.py:201] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
INFO 11-20 23:44:13 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=2048.
When dataset path is not set, it will default to random dataset
INFO 11-20 23:44:13 [datasets.py:613] Sampling input_len from [1023, 1023] and output_len from [100, 100]
INFO 11-20 23:44:14 [utils.py:253] non-default args: {'tokenizer': 'meta-llama/Llama-3.1-8B', 'dtype': 'bfloat16', 'kv_cache_dtype': 'fp8', 'seed': 0, 'max_model_len': 4096, 'num_redundant_experts': None, 'eplb_window_size': None, 'eplb_step_interval': None, 'eplb_log_balancedness': None, 'max_num_seqs': 128, 'enable_lora': None, 'reasoning_parser_plugin': '', 'model': 'meta-llama/Llama-3.1-8B'}
INFO 11-20 23:44:15 [model.py:644] Resolved architecture: LlamaForCausalLM
INFO 11-20 23:44:15 [model.py:1769] Using max model len 4096
INFO 11-20 23:44:15 [cache.py:195] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor.
INFO 11-20 23:44:15 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 11-20 23:44:15 [tpu_platform.py:119] Initialized sharding configuration: ShardingConfigManager(total_devices=1, sharding_strategy=ShardingStrategy(tensor_parallelism=1, expert_parallelism=1, sequence_parallelism=1, data_parallelism=1, attention_data_parallelism=1), device_indexes=None)
WARNING 11-20 23:44:15 [tpu_platform.py:156] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
INFO 11-20 23:44:15 [tpu_platform.py:190] Force using UniProcExecutor for JAX on                         single host without pipeline parallelism.
WARNING 11-20 23:44:16 [tpu_platform.py:231] Pin memory is not supported on TPU.
INFO 11-20 23:44:18 [__init__.py:25] TPU info: node_name=qizixi-tpu-v6e | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 11-20 23:44:19 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-20 23:44:19 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-20 23:44:19 [interface.py:201] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:21 [core.py:93] Initializing a V1 LLM engine (v0.1.dev11447+gcb0a7b4be) with config: model='meta-llama/Llama-3.1-8B', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=<class 'jax.numpy.bfloat16'>, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=fp8, device_config=None, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=meta-llama/Llama-3.1-8B, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': <CompilationMode.DYNAMO_TRACE_ONCE: 2>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'openxla', 'custom_ops': ['all'], 'splitting_ops': None, 'compile_mm_encoder': False, 'use_inductor': None, 'compile_sizes': None, 'inductor_compile_config': {'enable_auto_functionalized_v2': False}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': None, 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {}, 'max_cudagraph_capture_size': None, 'local_cache_dir': None}
(EngineCore_DP0 pid=439628) WARNING 11-20 23:44:21 [tpu_platform.py:231] Pin memory is not supported on TPU.
(EngineCore_DP0 pid=439628) WARNING 11-20 23:44:21 [tpu_worker.py:87] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:31 [parallel_state.py:1217] world_size=1 rank=0 local_rank=0 distributed_init_method=file:///tmp/tmp217loa8f backend=gloo
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:31 [parallel_state.py:1425] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, PCP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:31 [tpu_runner.py:303] Init mesh | mesh=Mesh('data': 1, 'model': 1, axis_types=(Auto, Auto))
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:31 [utils.py:93] Prepared token paddings: [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:31 [utils.py:59] Prepared request paddings: [8, 16, 32, 64, 128]
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:31 [compilation_manager.py:34] Enabling JAX compile cache.
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:31 [tpu_worker.py:246] Init worker | rank=0 | node_id=0 | is_driver_worker=True | hbm=[(0.0, 31.25)]GiB
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:31 [model_loader.py:319] Loading model with MODEL_IMPL_TYPE=flax_nnx
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:32 [weight_utils.py:119] Downloading weights from HF meta-llama/Llama-3.1-8B
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:32 [weight_utils.py:141] Loading weights from /home/qizixi/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00001-of-00004.safetensors
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:32 [weight_utils.py:141] Loading weights from /home/qizixi/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00002-of-00004.safetensors
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:32 [weight_utils.py:141] Loading weights from /home/qizixi/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00003-of-00004.safetensors
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:32 [weight_utils.py:141] Loading weights from /home/qizixi/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00004-of-00004.safetensors
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:37 [tpu_runner.py:527] Init model | hbm=[(14.96, 31.25)]GiB
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:37 [tpu_worker.py:275] Memory statistics | total_hbm_limit_gb=31.25GiB | total_hbm_limit_cap_gb=28.12GiB | total_hbm_used_gb=14.96GiB | total_hbm_avail_gb=13.16GiB
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:37 [kv_cache_utils.py:1234] GPU KV cache size: 215,552 tokens
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:37 [kv_cache_utils.py:1239] Maximum concurrency for 4,096 tokens per request: 52.62x
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:38 [kv_cache_manager.py:215] Init kv-cache | num_layers=32 | shape=(num_blocks, (256, 4, 4, 128)) | num_blocks=[842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842] | sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('data', None, 'model'), memory_kind=device) | dtype=float8_e4m3fn | hbm=[(28.12, 31.25)]Gb
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:38 [compilation_manager.py:73] Precompile all the subgraphs with possible input shapes.
(EngineCore_DP0 pid=439628) INFO 11-20 23:44:38 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 16}
(EngineCore_DP0 pid=439628) WARNING 11-20 23:44:38 [tuned_block_sizes.py:4077] Couldn`t find tuned sizes for the RPA v3 kernel with ('TPU v6e', 256, 'q_bfloat16_kv_float8_e4m3fn', 'q_head-32_kv_head-8_head-128', 4096)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] EngineCore failed to start.
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] Traceback (most recent call last):
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "<string>", line 1, in <module>
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/usr/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     exitcode = _main(fd, parent_sentinel)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/usr/lib/python3.12/multiprocessing/spawn.py", line 135, in _main
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return self._bootstrap(parent_sentinel)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.run()
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self._target(*self._args, **self._kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 834, in run_engine_core
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 607, in __init__
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     super().__init__(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 109, in __init__
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 248, in _initialize_kv_caches
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.model_executor.initialize_from_config(kv_cache_configs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/executor/abstract.py", line 116, in initialize_from_config
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.collective_rpc("compile_or_warm_up_model")
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/executor/uniproc_executor.py", line 75, in collective_rpc
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/serial_utils.py", line 479, in run_method
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return func(*args, **kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/worker/tpu_worker.py", line 361, in compile_or_warm_up_model
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.model_runner.capture_model()
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/tpu_runner.py", line 544, in capture_model
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.compilation_manager.capture_model()
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 76, in capture_model
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self._precompile_backbone_text_only()
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 251, in _precompile_backbone_text_only
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self._precompile_backbone_helper("backbone",
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 182, in _precompile_backbone_helper
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self._run_compilation(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 59, in _run_compilation
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     result = fn(*args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 172, in model_fn_wrapper
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     kv_caches, hidden_states, _ = self.runner.model_fn(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/common/model_loader.py", line 224, in run_model
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return model(*args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 323, in __call__
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     kv_caches, x, aux_hidden_states = self.model(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 292, in __call__
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     kv_cache, x = layer(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 213, in __call__
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     kv_cache, attn_output = self.self_attn(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/models/jax/llama3.py", line 157, in __call__
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     new_kv_cache, outputs = attention(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 372, in attention
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     output, kv_cache = sharded_ragged_paged_attention(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 331, in sharded_ragged_paged_attention
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return shard_map.shard_map(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/layers/common/attention_interface.py", line 322, in _ragged_paged_attention
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return func(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 1473, in ragged_paged_attention
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     output, updated_kv_cache = kernel(*scalar_prefetches, q, kv, kv_cache)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/usr/lib/python3.12/contextlib.py", line 81, in inner
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return func(*args, **kwds)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1715, in wrapped
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     jaxpr, consts = _trace_kernel_to_jaxpr(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1210, in _trace_kernel_to_jaxpr
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/primitives.py", line 874, in wrap_with_transforms
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return f(*new_args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 881, in _ragged_paged_attention_kernel
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     @pl.when(seq_idx < decode_end)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/helpers.py", line 70, in _wrapped
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     jax.lax.cond(condition, f, lambda: None)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 883, in process_decode
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     process(static_q_len=1)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 872, in process
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     lax.fori_loop(0, num_bq, compute_with_bq, None, unroll=False)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 849, in compute_with_bq
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     lax.fori_loop(0, num_bkv, compute_with_bkv, None, unroll=False)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 825, in compute_with_bkv
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     bkv_lst = strided_load_bkv(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 723, in strided_load_bkv
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return _convert_to_target_bitwidth(kv, target_bitwidth=bitwidth)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 713, in _convert_to_target_bitwidth
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     left_out = _convert_to_target_bitwidth(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/kernels/ragged_paged_attention/v3/kernel.py", line 699, in _convert_to_target_bitwidth
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     left = val.astype(next_dtype)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 1089, in meth
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return getattr(self.aval, name).fun(self, *args, **kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 122, in _astype
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return lax_numpy.astype(self, dtype, copy=copy, device=device)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 5370, in astype
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     result = lax_internal._convert_element_type(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Unsupported cast: uint16 -> uint32
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] 
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] 
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] --------------------
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] 
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] The above exception was the direct cause of the following exception:
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] 
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] Traceback (most recent call last):
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 834, in run_engine_core
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 607, in __init__
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     super().__init__(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 109, in __init__
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/engine/core.py", line 248, in _initialize_kv_caches
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.model_executor.initialize_from_config(kv_cache_configs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/executor/abstract.py", line 116, in initialize_from_config
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.collective_rpc("compile_or_warm_up_model")
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/executor/uniproc_executor.py", line 75, in collective_rpc
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     result = run_method(self.driver_worker, method, args, kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/vllm/vllm/v1/serial_utils.py", line 479, in run_method
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return func(*args, **kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/worker/tpu_worker.py", line 361, in compile_or_warm_up_model
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.model_runner.capture_model()
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/tpu_runner.py", line 544, in capture_model
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self.compilation_manager.capture_model()
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 76, in capture_model
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self._precompile_backbone_text_only()
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 251, in _precompile_backbone_text_only
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self._precompile_backbone_helper("backbone",
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 182, in _precompile_backbone_helper
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     self._run_compilation(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 59, in _run_compilation
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     result = fn(*args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]              ^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/tpu_inference/runner/compilation_manager.py", line 172, in model_fn_wrapper
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     kv_caches, hidden_states, _ = self.runner.model_fn(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]                                   ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1319, in _pallas_call_lowering
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return mlir.lower_per_platform(ctx, "pallas_call",
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/pallas_call.py", line 1292, in tpu_lowering
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return mosaic_tpu_backend.pallas_call_tpu_lowering_rule(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 161, in pallas_call_tpu_lowering_rule
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     mosaic_module, extra_args = lower_module(for_verification=False)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/pallas_call_registration.py", line 150, in lower_module
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return lowering.lower_jaxpr_to_module(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 736, in lower_jaxpr_to_module
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     func_op = lower_jaxpr_to_func(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]               ^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1050, in lower_jaxpr_to_func
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jaxlib/mlir/dialects/func.py", line 197, in decorator
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return_values = f(*func_args, **func_kwargs)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1046, in body_func
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return jaxpr_subcomp(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     raise
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3188, in _cond_lowering_rule
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     out = jaxpr_subcomp(lowering_context, branches[1].jaxpr, *args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     raise
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3058, in _scan_lowering_rule
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     out = _lower_jaxpr_to_for_loop(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]           ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3000, in _lower_jaxpr_to_for_loop
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     args = _run_body(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2983, in _run_body
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     raise
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3110, in _while_lowering_rule
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return _lower_while_via_fori(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3080, in _lower_while_via_fori
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     for_out = _lower_jaxpr_to_for_loop(
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]               ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 3015, in _lower_jaxpr_to_for_loop
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     inner_out = _run_body(iv, inner_args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]                 ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2983, in _run_body
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     args = jaxpr_subcomp(lowering_context, jaxpr, *consts, i, *args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     raise
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2225, in _convert_element_type_lowering_rule
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype),
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1077, in f_lowered
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     out = jaxpr_subcomp(lowering_context, jaxpr, *consts, *args)
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 1178, in jaxpr_subcomp
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     raise
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]   File "/home/qizixi/tpu-inference/.venv/lib/python3.12/site-packages/jax/_src/pallas/mosaic/lowering.py", line 2208, in _convert_element_type_lowering_rule
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843]     raise NotImplementedError(f"Unsupported cast: {old_dtype} -> {new_dtype}")
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] NotImplementedError: Unsupported cast: uint16 -> uint32
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] --------------------
(EngineCore_DP0 pid=439628) ERROR 11-20 23:44:40 [core.py:843] For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

@kyuyeunk
Collaborator

Ah okay, I think I know what the problem is.

I believe I used a fairly recent JAX feature when I wrote this: #818

What is your JAX version? And can you update it to the latest one?
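
For reference, a quick way to check what's installed (plain JAX attributes, nothing repo-specific):

import jax
import jaxlib

print(jax.__version__, jaxlib.__version__)

# Upgrading on TPU, per the JAX install docs:
#   pip install -U "jax[tpu]"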

I have verified that your PR works without error (and automatically interprets fp8 as fp8_e5m2) when I ran the command you pasted:

$ vllm bench throughput --model meta-llama/Llama-3.1-8B --tensor-parallel-size 1 --dtype bfloat16 --kv-cache-dtype fp8 --max-model-len 4096 --max-num-seqs 128 --num-prompts 100 --dataset-name random --input-len 1024 --output-len 100
INFO 11-21 08:55:37 [__init__.py:25] TPU info: node_name=kyuyeunk-v6e-8 | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 11-21 08:55:40 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-21 08:55:40 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-21 08:55:40 [interface.py:201] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
INFO 11-21 08:55:40 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:40 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:40 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:41 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=2048.
When dataset path is not set, it will default to random dataset
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50.5k/50.5k [00:00<00:00, 5.62MB/s]
tokenizer.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.09M/9.09M [00:00<00:00, 66.9MB/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 73.0/73.0 [00:00<00:00, 1.13MB/s]
INFO 11-21 08:55:42 [datasets.py:613] Sampling input_len from [1023, 1023] and output_len from [100, 100]
INFO 11-21 08:55:43 [utils.py:253] non-default args: {'tokenizer': 'meta-llama/Llama-3.1-8B', 'download_dir': '/mnt/disks/persist', 'dtype': 'bfloat16', 'kv_cache_dtype': 'fp8', 'seed': 0, 'max_model_len': 4096, 'num_redundant_experts': None, 'eplb_window_size': None, 'eplb_step_interval': None, 'eplb_log_balancedness': None, 'max_num_seqs': 128, 'enable_lora': None, 'reasoning_parser_plugin': '', 'model': 'meta-llama/Llama-3.1-8B'}
config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 826/826 [00:00<00:00, 11.1MB/s]
INFO 11-21 08:55:49 [model.py:645] Resolved architecture: LlamaForCausalLM
INFO 11-21 08:55:49 [model.py:1765] Using max model len 4096
INFO 11-21 08:55:49 [cache.py:180] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor.
INFO 11-21 08:55:49 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 11-21 08:55:49 [tpu_platform.py:127] Initialized sharding configuration: ShardingConfigManager(total_devices=1, sharding_strategy=ShardingStrategy(tensor_parallelism=1, expert_parallelism=1, sequence_parallelism=1, data_parallelism=1, attention_data_parallelism=1), device_indexes=None)
WARNING 11-21 08:55:49 [tpu_platform.py:164] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
INFO 11-21 08:55:49 [tpu_platform.py:198] Force using UniProcExecutor for JAX on                         single host without pipeline parallelism.
generation_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 185/185 [00:00<00:00, 2.30MB/s]
WARNING 11-21 08:55:50 [tpu_platform.py:239] Pin memory is not supported on TPU.
INFO 11-21 08:55:52 [__init__.py:25] TPU info: node_name=kyuyeunk-v6e-8 | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 11-21 08:55:55 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-21 08:55:55 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-21 08:55:55 [interface.py:201] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
INFO 11-21 08:55:55 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:55 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:55 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
(EngineCore_DP0 pid=642275) INFO 11-21 08:55:56 [core.py:93] Initializing a V1 LLM engine (v0.11.0rc2.dev1486+gae4821a10) with config: model='meta-llama/Llama-3.1-8B', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=<class 'jax.numpy.bfloat16'>, max_seq_len=4096, download_dir='/mnt/disks/persist', load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=fp8, device_config=None, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=meta-llama/Llama-3.1-8B, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': <CompilationMode.DYNAMO_TRACE_ONCE: 2>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'openxla', 'custom_ops': ['all'], 'splitting_ops': None, 'compile_mm_encoder': False, 'use_inductor': None, 'compile_sizes': None, 'inductor_compile_config': {'enable_auto_functionalized_v2': False}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': None, 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {}, 'max_cudagraph_capture_size': None, 'local_cache_dir': None}
(EngineCore_DP0 pid=642275) WARNING 11-21 08:55:56 [tpu_platform.py:239] Pin memory is not supported on TPU.
(EngineCore_DP0 pid=642275) WARNING 11-21 08:55:57 [tpu_worker.py:87] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:06 [parallel_state.py:1208] world_size=1 rank=0 local_rank=0 distributed_init_method=file:///tmp/tmp_woct2fz backend=gloo
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:06 [parallel_state.py:1394] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:06 [tpu_runner.py:303] Init mesh | mesh=Mesh('data': 1, 'model': 1, axis_types=(Auto, Auto))
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:06 [utils.py:93] Prepared token paddings: [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:06 [utils.py:59] Prepared request paddings: [8, 16, 32, 64, 128]
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:06 [compilation_manager.py:34] Enabling JAX compile cache.
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:06 [tpu_worker.py:246] Init worker | rank=0 | node_id=0 | is_driver_worker=True | hbm=[(0.0, 31.25)]GiB
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:06 [model_loader.py:319] Loading model with MODEL_IMPL_TYPE=flax_nnx
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:07 [weight_utils.py:119] Downloading weights from HF meta-llama/Llama-3.1-8B
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:07 [weight_utils.py:141] Loading weights from /mnt/disks/persist/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00001-of-00004.safetensors
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:07 [weight_utils.py:141] Loading weights from /mnt/disks/persist/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00002-of-00004.safetensors
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:11 [weight_utils.py:141] Loading weights from /mnt/disks/persist/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00003-of-00004.safetensors
(EngineCore_DP0 pid=642275) INFO 11-21 08:56:11 [weight_utils.py:141] Loading weights from /mnt/disks/persist/models--meta-llama--Llama-3.1-8B/snapshots/d04e592bb4f6aa9cfee91e2e20afa771667e1d4b/model-00004-of-00004.safetensors
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:13 [tpu_runner.py:527] Init model | hbm=[(14.96, 31.25)]GiB
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:13 [tpu_worker.py:275] Memory statistics | total_hbm_limit_gb=31.25GiB | total_hbm_limit_cap_gb=28.12GiB | total_hbm_used_gb=14.96GiB | total_hbm_avail_gb=13.16GiB
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:13 [kv_cache_utils.py:1229] GPU KV cache size: 215,552 tokens
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:13 [kv_cache_utils.py:1234] Maximum concurrency for 4,096 tokens per request: 52.62x
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:15 [kv_cache_manager.py:215] Init kv-cache | num_layers=32 | shape=(num_blocks, (256, 4, 4, 128)) | num_blocks=[842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842, 842] | sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 1, axis_types=(Auto, Auto)), spec=PartitionSpec('data', None, 'model'), memory_kind=device) | dtype=float8_e4m3fn | hbm=[(28.12, 31.25)]GiB
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:15 [compilation_manager.py:73] Precompile all the subgraphs with possible input shapes.
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:15 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 16}
(EngineCore_DP0 pid=642275) WARNING 11-21 08:57:15 [tuned_block_sizes.py:4077] Couldn't find tuned sizes for the RPA v3 kernel with ('TPU v6e', 256, 'q_bfloat16_kv_float8_e4m3fn', 'q_head-32_kv_head-8_head-128', 4096)
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:21 [compilation_manager.py:67] Compilation finished in 6.71 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:21 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 32}
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:29 [compilation_manager.py:67] Compilation finished in 7.58 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:29 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 64}
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:37 [compilation_manager.py:67] Compilation finished in 8.07 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:37 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 128}
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:45 [compilation_manager.py:67] Compilation finished in 7.89 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:45 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 256}
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:54 [compilation_manager.py:67] Compilation finished in 8.62 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:57:54 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 512}
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:02 [compilation_manager.py:67] Compilation finished in 8.57 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:02 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 1024}
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:12 [compilation_manager.py:67] Compilation finished in 9.42 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:12 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 2048}
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:21 [compilation_manager.py:67] Compilation finished in 9.53 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:21 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 4096}
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:31 [compilation_manager.py:67] Compilation finished in 9.68 [secs].
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:31 [compilation_manager.py:57] Precompile backbone --> {'num_tokens': 8192}
(EngineCore_DP0 pid=642275) INFO 11-21 08:58:43 [compilation_manager.py:67] Compilation finished in 11.59 [secs].
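
As a quick sanity check, the cache-sizing figures in this log are internally consistent; the calculation below is generic, and the variable names are illustrative, not tpu-inference APIs:

```python
# Cross-checking the kv-cache lines reported above.
num_blocks, block_size = 842, 256          # from num_blocks=[842, ...] and shape (256, 4, 4, 128)
kv_cache_tokens = num_blocks * block_size
assert kv_cache_tokens == 215_552          # "GPU KV cache size: 215,552 tokens"

# Maximum concurrency = cache capacity / per-request token budget.
max_model_len = 4_096                      # --max-model-len 4096
print(f"{kv_cache_tokens / max_model_len:.2f}x")  # 52.62x, matching the log
```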

@zixi-qi zixi-qi (Author) commented Nov 21, 2025

> Ah okay, I think I know what the problem is.
>
> I believe I used some of the latest JAX features when I wrote this: #818
>
> What is your JAX version? Can you update it to the latest one?
>
> I have verified that your PR works without error (and automatically interprets fp8 as fp8_e5m2) when I ran the command you pasted:

$ vllm bench throughput --model meta-llama/Llama-3.1-8B --tensor-parallel-size 1 --dtype bfloat16 --kv-cache-dtype fp8 --max-model-len 4096 --max-num-seqs 128 --num-prompts 100 --dataset-name random --input-len 1024 --output-len 100
INFO 11-21 08:55:37 [__init__.py:25] TPU info: node_name=kyuyeunk-v6e-8 | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 11-21 08:55:40 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-21 08:55:40 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-21 08:55:40 [interface.py:201] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
INFO 11-21 08:55:40 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:40 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:40 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:41 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=2048.
When dataset path is not set, it will default to random dataset
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50.5k/50.5k [00:00<00:00, 5.62MB/s]
tokenizer.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9.09M/9.09M [00:00<00:00, 66.9MB/s]
special_tokens_map.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 73.0/73.0 [00:00<00:00, 1.13MB/s]
INFO 11-21 08:55:42 [datasets.py:613] Sampling input_len from [1023, 1023] and output_len from [100, 100]
INFO 11-21 08:55:43 [utils.py:253] non-default args: {'tokenizer': 'meta-llama/Llama-3.1-8B', 'download_dir': '/mnt/disks/persist', 'dtype': 'bfloat16', 'kv_cache_dtype': 'fp8', 'seed': 0, 'max_model_len': 4096, 'num_redundant_experts': None, 'eplb_window_size': None, 'eplb_step_interval': None, 'eplb_log_balancedness': None, 'max_num_seqs': 128, 'enable_lora': None, 'reasoning_parser_plugin': '', 'model': 'meta-llama/Llama-3.1-8B'}
config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 826/826 [00:00<00:00, 11.1MB/s]
INFO 11-21 08:55:49 [model.py:645] Resolved architecture: LlamaForCausalLM
INFO 11-21 08:55:49 [model.py:1765] Using max model len 4096
INFO 11-21 08:55:49 [cache.py:180] Using fp8 data type to store kv cache. It reduces the GPU memory footprint and boosts the performance. Meanwhile, it may cause accuracy drop without a proper scaling factor.
INFO 11-21 08:55:49 [scheduler.py:216] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 11-21 08:55:49 [tpu_platform.py:127] Initialized sharding configuration: ShardingConfigManager(total_devices=1, sharding_strategy=ShardingStrategy(tensor_parallelism=1, expert_parallelism=1, sequence_parallelism=1, data_parallelism=1, attention_data_parallelism=1), device_indexes=None)
WARNING 11-21 08:55:49 [tpu_platform.py:164] The model dtype is not properly set for JAX backend. Overwriting it to jnp.bfloat16
INFO 11-21 08:55:49 [tpu_platform.py:198] Force using UniProcExecutor for JAX on single host without pipeline parallelism.
generation_config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 185/185 [00:00<00:00, 2.30MB/s]
WARNING 11-21 08:55:50 [tpu_platform.py:239] Pin memory is not supported on TPU.
INFO 11-21 08:55:52 [__init__.py:25] TPU info: node_name=kyuyeunk-v6e-8 | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 11-21 08:55:55 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 11-21 08:55:55 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
WARNING 11-21 08:55:55 [interface.py:201] Failed to import from vllm._C: ModuleNotFoundError("No module named 'vllm._C'")
INFO 11-21 08:55:55 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:55 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
INFO 11-21 08:55:55 [tpu_platform.py:89] Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.
(EngineCore_DP0 pid=642275) INFO 11-21 08:55:56 [core.py:93] Initializing a V1 LLM engine (v0.11.0rc2.dev1486+gae4821a10) with config: model='meta-llama/Llama-3.1-8B', speculative_config=None, tokenizer='meta-llama/Llama-3.1-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=<class 'jax.numpy.bfloat16'>, max_seq_len=4096, download_dir='/mnt/disks/persist', load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=fp8, device_config=None, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=meta-llama/Llama-3.1-8B, enable_prefix_caching=True, enable_chunked_prefill=True, pooler_config=None, compilation_config={'level': None, 'mode': <CompilationMode.DYNAMO_TRACE_ONCE: 2>, 'debug_dump_path': None, 'cache_dir': '', 'compile_cache_save_format': 'binary', 'backend': 'openxla', 'custom_ops': ['all'], 'splitting_ops': None, 'compile_mm_encoder': False, 'use_inductor': None, 'compile_sizes': None, 'inductor_compile_config': {'enable_auto_functionalized_v2': False}, 'inductor_passes': {}, 'cudagraph_mode': <CUDAGraphMode.NONE: 0>, 'cudagraph_num_of_warmups': 0, 'cudagraph_capture_sizes': None, 'cudagraph_copy_inputs': False, 'cudagraph_specialize_lora': True, 'use_inductor_graph_partition': False, 'pass_config': {}, 'max_cudagraph_capture_size': None, 'local_cache_dir': None}
[... remainder of the engine-core startup and precompilation log is identical to the output shown above, through "Compilation finished in 11.59 [secs]." ...]

Yeah, it seems this was the issue; thanks for debugging it! After updating JAX, the e2e test now passes.
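
A quick, generic way to confirm which JAX is installed, and that the float8 dtypes the fp8 KV cache relies on are available, is from a Python prompt (standard JAX API, nothing tpu-inference-specific):

```python
import jax
import jax.numpy as jnp

print(jax.__version__)                     # should be a recent release
print(jnp.float8_e5m2, jnp.float8_e4m3fn)  # fp8 dtypes used for the KV cache
```

The benchmark output after updating JAX: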

vllm bench throughput --model meta-llama/Llama-3.1-8B --tensor-parallel-size 1 --dtype bfloat16 --kv-cache-dtype fp8 --max-model-len 4096 --max-num-seqs 128 --num-prompts 100 --dataset-name random --input-len 1024 --output-len 100

INFO 11-21 18:17:43 [llm.py:352] Supported tasks: ['generate']
Adding requests: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 610.10it/s]
Processed prompts: 100%|██████████████████████████████████████████████████| 100/100 [00:05<00:00, 17.45it/s, est. speed input: 17870.47 toks/s, output: 1745.16 toks/s]
Throughput: 16.96 requests/s, 19066.22 total tokens/s, 1696.28 output tokens/s
Total num prompt tokens:  102400
Total num output tokens:  10000
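
The totals are consistent with the run configuration: 100 prompts × 1024 input tokens = 102,400, and 100 × 100 output tokens = 10,000. For readers following along, the behavior exercised here, narrowing a generic fp8 KV-cache request to fp8_e5m2 on TPU v6e (Trillium), can be sketched roughly as below; the function name and parameters are illustrative assumptions, not the actual tpu_platform.py API:

```python
# Rough, hypothetical sketch of the dtype override described in this PR.
# `resolve_kv_cache_dtype` and `tpu_generation` are illustrative names only.
import jax.numpy as jnp

def resolve_kv_cache_dtype(requested: str, tpu_generation: str):
    """Map a user-requested --kv-cache-dtype string to a concrete JAX dtype."""
    if requested == "fp8" and tpu_generation == "v6e":
        # On Trillium, a bare "fp8" request is narrowed to e5m2 automatically,
        # matching the "Automatically using fp8_e5m2 ..." log line above.
        print("Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
        return jnp.float8_e5m2
    if requested == "fp8_e5m2":
        return jnp.float8_e5m2
    if requested in ("fp8", "fp8_e4m3"):
        return jnp.float8_e4m3fn
    return jnp.bfloat16  # unquantized fallback
```

With a mapping like this in place, passing --kv-cache-dtype fp8 on a v6e host behaves the same as requesting fp8_e5m2 explicitly, which is what the benchmark above exercises.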

@kyuyeunk kyuyeunk (Collaborator) left a comment

LGTM, but this will require our CI to pass before being merged.

I'll run it manually and get back to you when I have the results (might take some time; a few days at most ¯\_(ツ)_/¯).
