Skip to content

Commit c48b590

Browse files
committed
Deprecate JAX_RANDOM_WEIGHTS
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 0c66fde commit c48b590

File tree

7 files changed

+18
-19
lines changed

7 files changed

+18
-19
lines changed

.buildkite/pipeline_jax.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,12 @@ steps:
190190
USE_V6E8_QUEUE: "True"
191191
SKIP_ACCURACY_TESTS: "True"
192192
VLLM_MLA_DISABLE: "1"
193-
JAX_RANDOM_WEIGHTS: "True"
194193
agents:
195194
queue: tpu_v6e_8_queue
196195
commands:
197196
- |
198197
if [[ "$$NIGHTLY" == "1" ]]; then
199-
.buildkite/scripts/run_in_docker.sh bash /workspace/tpu_inference/tests/e2e/benchmarking/mlperf.sh -m deepseek-ai/DeepSeek-R1-0528
198+
.buildkite/scripts/run_in_docker.sh bash /workspace/tpu_inference/tests/e2e/benchmarking/mlperf.sh -m deepseek-ai/DeepSeek-R1-0528 --dummy
200199
else
201200
echo "Skipping: NIGHTLY environment variable not set"
202201
exit 0

.buildkite/scripts/run_in_docker.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ exec docker run \
108108
${QUANTIZATION:+-e QUANTIZATION="$QUANTIZATION"} \
109109
${NEW_MODEL_DESIGN:+-e NEW_MODEL_DESIGN="$NEW_MODEL_DESIGN"} \
110110
${USE_V6E8_QUEUE:+-e USE_V6E8_QUEUE="$USE_V6E8_QUEUE"} \
111-
${JAX_RANDOM_WEIGHTS:+-e JAX_RANDOM_WEIGHTS="$JAX_RANDOM_WEIGHTS"} \
112111
${SKIP_ACCURACY_TESTS:+-e SKIP_ACCURACY_TESTS="$SKIP_ACCURACY_TESTS"} \
113112
${VLLM_MLA_DISABLE:+-e VLLM_MLA_DISABLE="$VLLM_MLA_DISABLE"} \
114113
"${IMAGE_NAME}:${BUILDKITE_COMMIT}" \

tests/e2e/benchmarking/mlperf.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dataset_name=mlperf
4747
dataset_path=""
4848
num_prompts=1000
4949
exit_code=0
50+
dummy=false
5051

5152
helpFunction()
5253
{
@@ -57,6 +58,7 @@ helpFunction()
5758
echo -e "\t-p The path to the processed MLPerf dataset (default: None, which will download the dataset)"
5859
echo -e "\t-m A space-separated list of HuggingFace model ids to use (default: Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-0.5B-Instruct, meta-llama/Llama-3.1-8B-Instruct and meta-llama/Llama-4-Scout-17B-16E-Instruct)"
5960
echo -e "\t-n Number of prompts to use for the benchmark (default: 10)"
61+
echo -e "\t--dummy Use dummy random weight (default: false)"
6062
exit 1
6163
}
6264

@@ -87,6 +89,11 @@ while [[ "$#" -gt 0 ]]; do
8789
shift
8890
shift
8991
;;
92+
--dummy)
93+
dummy=true
94+
shift
95+
shift
96+
;;
9097
-h|--help)
9198
helpFunction
9299
;;
@@ -121,6 +128,10 @@ if [ -z "$dataset_path" ]; then
121128
fi
122129
fi
123130

131+
if [ "$dummy" = true ]; then
132+
extra_serve_args+=("--load-format=dummy")
133+
fi
134+
124135
echo "Using the dataset at $dataset_path"
125136

126137
cd "$root_dir"/vllm || exit

tests/models/common/test_model_loader.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,17 +254,13 @@ def test_get_vllm_model(mesh):
254254
assert callable(compute_logits_fn)
255255

256256

257-
@pytest.mark.parametrize("set_in_config", [True, False])
258-
def test_get_vllm_model_random_weights(mesh, set_in_config):
257+
def test_get_vllm_model_random_weights(mesh):
259258
rng = jax.random.PRNGKey(42)
260259

261260
engine_args = EngineArgs(model="Qwen/Qwen3-0.6B")
262261
vllm_config = engine_args.create_engine_config()
263262
vllm_config.model_config.dtype = torch.bfloat16
264-
if set_in_config:
265-
vllm_config.load_config.load_format = "dummy"
266-
else:
267-
os.environ["JAX_RANDOM_WEIGHTS"] = "True"
263+
vllm_config.load_config.load_format = "dummy"
268264

269265
with set_current_vllm_config(vllm_config):
270266
temp_file = tempfile.mkstemp()[1]

tpu_inference/models/common/model_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def create_jit_model(
103103
apply_to_abstract_model=False)
104104
return model
105105

106-
if os.getenv("JAX_RANDOM_WEIGHTS", False):
106+
if vllm_config.load_config.load_format == "dummy":
107107
# Create a sharded model with randomly initialized weights.
108108
# TODO: currently Qwen2ForCausalLM is using legacy model implementation
109109
# will merge the random init logic once all models are migrated to the new model implementation

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import copy
22
import functools
3-
import os
43
from collections.abc import Sequence
54
from contextlib import nullcontext
65
from typing import Any, List, Optional, Tuple
@@ -91,12 +90,8 @@ def load_weights(self):
9190
# may cause errors. Therefore, we disable it during weight loading.
9291
vllm_config_for_load.parallel_config.enable_expert_parallel = False
9392

94-
if os.getenv("JAX_RANDOM_WEIGHTS", False):
95-
vllm_config_for_load.load_config.load_format = "dummy"
96-
use_random_weights = True
97-
else:
98-
use_random_weights = (
99-
vllm_config_for_load.load_config.load_format == "dummy")
93+
use_random_weights = (
94+
vllm_config_for_load.load_config.load_format == "dummy")
10095
if use_random_weights:
10196
logger.info(
10297
"Initializing vLLM model with random weights, weight loading skipped."

tpu_inference/platforms/tpu_jax.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ class TpuPlatform(Platform):
4848
]
4949

5050
additional_env_vars: list[str] = [
51-
"JAX_RANDOM_WEIGHTS", "PHASED_PROFILING_DIR",
52-
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
51+
"PHASED_PROFILING_DIR", "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS",
5352
"TPU_MULTIHOST_BACKEND", "VLLM_MLA_DISABLE", "TPU_BACKEND_TYPE"
5453
]
5554

0 commit comments

Comments
 (0)