Commit 12e873f

Remove SKIP_JAX_PRECOMPILE
- Same functionality can be achieved with the vLLM argument --enforce-eager
- It is better to remove duplicate configs to avoid confusing users

Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 7227930 commit 12e873f
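
The change is a one-for-one swap: instead of exporting the tpu_inference-specific SKIP_JAX_PRECOMPILE environment variable, callers pass vLLM's existing --enforce-eager flag (or enforce_eager=True from Python). A minimal before/after sketch; the prompt is illustrative and the model name is taken from the example defaults touched below:

# Before: tpu_inference-specific environment variable.
#   SKIP_JAX_PRECOMPILE=1 python3 examples/offline_inference.py ...
# After: vLLM's own engine argument, e.g. from Python:
from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)

For the server path, the equivalent is appending --enforce-eager to the vllm serve command line, as the updated scripts below do.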

File tree

11 files changed: +15 -32 lines changed


.buildkite/models/google_gemma-3-27b-it.yml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ steps:
   commands:
   - |
     .buildkite/scripts/run_in_docker.sh \
-      bash -c 'SKIP_JAX_PRECOMPILE=1 VLLM_XLA_CHECK_RECOMPILATION=0 python3 /workspace/tpu_inference/examples/offline_inference.py --model=google/gemma-3-27b-it --tensor_parallel_size=8 --task=generate --max_model_len=1024 --max_num_seqs=1'
+      bash -c 'VLLM_XLA_CHECK_RECOMPILATION=0 python3 /workspace/tpu_inference/examples/offline_inference.py --model=google/gemma-3-27b-it --tensor_parallel_size=8 --task=generate --max_model_len=1024 --max_num_seqs=1'
 - label: "Record unit test result for google/gemma-3-27b-it"
   key: "record_google_gemma-3-27b-it_UnitTest"
   depends_on: "google_gemma-3-27b-it_UnitTest"

examples/disagg/run_disagg_multi_host.sh

Lines changed: 2 additions & 2 deletions
@@ -63,7 +63,6 @@ for ((i=0; i<NUM_HOSTS_PER_INSTANCE; i++)); do
     -e TPU_KV_TRANSFER_PORT="${KV_PORT}" \
     -e TPU_SIDE_CHANNEL_PORT="${SIDE_PORT}" \
     -e RAY_DEDUP_LOGS="0" \
-    -e SKIP_JAX_PRECOMPILE="1" \
     \
     -e TPU_CHIPS_PER_PROCESS_BOUNDS="1,1,1" \
     -e TPU_PROCESS_BOUNDS="2,2,1" \
@@ -95,6 +94,7 @@ docker exec node-0 /bin/bash -c \
     --gpu-memory-utilization 0.3 \
     --tensor-parallel-size 4 \
     --kv-transfer-config '{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_producer\"}' \
+    --enforce-eager \
     > /root/logs/prefill.txt 2>&1 &"
 set +x
 
@@ -137,7 +137,6 @@ for ((i=0; i<NUM_HOSTS_PER_INSTANCE; i++)); do
     -e TPU_KV_TRANSFER_PORT="${KV_PORT}" \
     -e TPU_SIDE_CHANNEL_PORT="${SIDE_PORT}" \
     -e RAY_DEDUP_LOGS="0" \
-    -e SKIP_JAX_PRECOMPILE="1" \
     \
     -e TPU_CHIPS_PER_PROCESS_BOUNDS="1,1,1" \
     -e TPU_PROCESS_BOUNDS="2,2,1" \
@@ -169,5 +168,6 @@ docker exec node-20 /bin/bash -c \
     --gpu-memory-utilization 0.3 \
     --tensor-parallel-size 4 \
     --kv-transfer-config '{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_consumer\"}' \
+    --enforce-eager \
     > /root/logs/decode.txt 2>&1 &"
 set +x

examples/disagg/run_disagg_single_host.sh

Lines changed: 2 additions & 2 deletions
@@ -45,13 +45,13 @@ for i in $(seq 0 $((NUM_PREFILL_INSTANCES-1))); do
     \
     TPU_KV_TRANSFER_PORT=$KV_PORT \
     TPU_SIDE_CHANNEL_PORT=$SIDE_PORT \
-    SKIP_JAX_PRECOMPILE=1 \
     \
     vllm serve $MODEL \
         --port $PORT \
         --gpu-memory-utilization 0.2 \
         --tensor-parallel-size $PREFILLER_TP_SIZE \
         --kv-transfer-config "{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_producer\"}" \
+        --enforce-eager \
         > $HOME/logs/prefill_$i.txt 2>&1 &
 
     PREFILL_HOSTS+=("localhost")
@@ -72,13 +72,13 @@ for i in $(seq 0 $((NUM_DECODE_INSTANCES-1))); do
     \
     TPU_KV_TRANSFER_PORT=$KV_PORT \
     TPU_SIDE_CHANNEL_PORT=$SIDE_PORT \
-    SKIP_JAX_PRECOMPILE=1 \
     \
     vllm serve $MODEL \
         --port $PORT \
         --gpu-memory-utilization 0.2 \
         --tensor-parallel-size $DECODER_TP_SIZE \
         --kv-transfer-config "{\"kv_connector\":\"TPUConnector\",\"kv_connector_module_path\":\"tpu_inference.distributed.tpu_connector\",\"kv_role\":\"kv_consumer\"}" \
+        --enforce-eager \
         > $HOME/logs/decode_$i.txt 2>&1 &
 
     DECODE_HOSTS+=("localhost")

examples/offline_inference.py

Lines changed: 3 additions & 5 deletions
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import os
-
 import vllm.envs as envs
 from vllm import LLM, EngineArgs
 from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -17,6 +15,9 @@ def create_parser():
     parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
     parser.set_defaults(max_model_len=1024)
 
+    # Skip long warmup for local simple test.
+    parser.set_defaults(enforce_eager=True)
+
     # Add sampling params
     sampling_group = parser.add_argument_group("Sampling parameters")
     sampling_group.add_argument("--max-tokens", type=int)
@@ -103,9 +104,6 @@ def main(args: dict):
 
 
 if __name__ == "__main__":
-    # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-
     parser = create_parser()
     args: dict = vars(parser.parse_args())
 

examples/offline_lora_inference.py

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import os
 import time
 
 import vllm.envs as envs
@@ -20,6 +19,9 @@ def create_parser():
     parser.set_defaults(enable_lora=True)
     parser.set_defaults(max_lora_rank=8)
 
+    # Skip long warmup for local simple test.
+    parser.set_defaults(enforce_eager=True)
+
     # Add sampling params
     sampling_group = parser.add_argument_group("Sampling parameters")
     sampling_group.add_argument("--max-tokens", type=int, default=16)
@@ -76,9 +78,6 @@ def main(args: dict):
 
 
 if __name__ == "__main__":
-    # Skip long warmup for local simple test.
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
-
     parser = create_parser()
     args: dict = vars(parser.parse_args())
 

tests/e2e/benchmarking/mm_bench.sh

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ checkThroughputAndRouge() {
 }
 
 echo "Spinning up the vLLM server..."
-(SKIP_JAX_PRECOMPILE=1 VLLM_XLA_CHECK_RECOMPILATION=0 vllm serve "$model_name" --max-model-len "$max_model_len" --max-num-seqs "$max_num_seqs" --disable-log-requests --max-num-batched-tokens "$max_batched_tokens" 2>&1 | tee -a "$LOG_FILE") &
+(VLLM_XLA_CHECK_RECOMPILATION=0 vllm serve "$model_name" --max-model-len "$max_model_len" --max-num-seqs "$max_num_seqs" --disable-log-requests --max-num-batched-tokens "$max_batched_tokens" --enforce-eager 2>&1 | tee -a "$LOG_FILE") &
 
 
 # Run a busy loop to block until the server is ready to receive requests

tests/e2e/test_data_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -67,6 +67,7 @@ def _run_inference_with_config(model_name: str,
         additional_config=additional_config,
         kv_cache_dtype=kv_cache_dtype,
         async_scheduling=async_scheduling,
+        enforce_eager=True,
     )
 
     engine_args_dict = asdict(engine_args)
@@ -173,7 +174,6 @@ def test_data_parallelism_correctness(
     This test compares outputs from a single-device run with data parallel runs
     to ensure correctness, including log probabilities.
     """
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
     os.environ['VLLM_XLA_CHECK_RECOMPILATION'] = '0'
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
     # Use a smaller subset of prompts for correctness testing

tests/e2e/test_multi_modal_inference.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,6 @@ def test_multi_modal_inference(monkeypatch):
     """
     Runs multi-modal inference and verifies the output.
     """
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'  # Skip warmup to save time.
     os.environ[
         'VLLM_XLA_CHECK_RECOMPILATION'] = '0'  # Allow compilation during execution.
 
@@ -65,6 +64,7 @@ def test_multi_modal_inference(monkeypatch):
             "fps": 1,
         },
         limit_mm_per_prompt={modality: 1},
+        enforce_eager=True,  # Skip warmup to save time.
     )
     engine_args = asdict(engine_args)
     llm = LLM(**engine_args)

tests/test_envs.py

Lines changed: 0 additions & 8 deletions
@@ -56,13 +56,6 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
 
 
 def test_boolean_env_vars(monkeypatch: pytest.MonkeyPatch):
-    # Test SKIP_JAX_PRECOMPILE (default False)
-    assert envs.SKIP_JAX_PRECOMPILE is False
-    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "1")
-    assert envs.SKIP_JAX_PRECOMPILE is True
-    monkeypatch.setenv("SKIP_JAX_PRECOMPILE", "0")
-    assert envs.SKIP_JAX_PRECOMPILE is False
-
     # Test NEW_MODEL_DESIGN (default False)
     assert envs.NEW_MODEL_DESIGN is False
     monkeypatch.setenv("NEW_MODEL_DESIGN", "1")
@@ -133,7 +126,6 @@ def test_dir_returns_all_env_vars():
     assert len(env_vars) == len(environment_variables)
     assert "JAX_PLATFORMS" in env_vars
     assert "TPU_NAME" in env_vars
-    assert "SKIP_JAX_PRECOMPILE" in env_vars
     assert "MODEL_IMPL_TYPE" in env_vars
 
 

tpu_inference/envs.py

Lines changed: 0 additions & 4 deletions
@@ -14,7 +14,6 @@
     TPU_MULTIHOST_BACKEND: str = ""
     PREFILL_SLICES: str = ""
     DECODE_SLICES: str = ""
-    SKIP_JAX_PRECOMPILE: bool = False
     MODEL_IMPL_TYPE: str = "flax_nnx"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
@@ -45,9 +44,6 @@
     # Slice configuration for disaggregated decode workers
     "DECODE_SLICES":
     lambda: os.getenv("DECODE_SLICES", ""),
-    # Skip JAX precompilation step during initialization
-    "SKIP_JAX_PRECOMPILE":
-    lambda: bool(int(os.getenv("SKIP_JAX_PRECOMPILE", "0"))),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
     lambda: os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower(),
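
For context, tpu_inference/envs.py resolves these entries lazily: each name maps to a zero-argument lambda that reads and parses the variable when it is accessed, which is why dropping the registry entry above also removes the envs.SKIP_JAX_PRECOMPILE attribute exercised in tests/test_envs.py. A minimal sketch of that registry pattern, assuming a PEP 562 module-level __getattr__ resolver; everything here other than the two variable names taken from this diff is illustrative, not the actual tpu_inference implementation:

import os

# Registry mapping env-var names to parsers, mirroring the style in this diff.
environment_variables = {
    # String variable with an empty default, as with DECODE_SLICES above.
    "DECODE_SLICES": lambda: os.getenv("DECODE_SLICES", ""),
    # Boolean flags use bool(int(...)), the same shape the removed
    # SKIP_JAX_PRECOMPILE entry used before this commit.
    "NEW_MODEL_DESIGN": lambda: bool(int(os.getenv("NEW_MODEL_DESIGN", "0"))),
}

def __getattr__(name: str):
    # Hypothetical resolver: evaluate the registered lambda on attribute access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")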
