[DistInf] Enable SGLang disagg with MoRI-io & MoRI-EP for all existing models #154
Open: lcskrishna wants to merge 1 commit into ROCm:develop from lcskrishna:csrikris-pr222.
New file (diff header `@@ -0,0 +1,101 @@`): model catalog YAML.

```yaml
# Model catalog for sglang_disagg_mori_io_ep.sh (MoRI IO transfer + disaggregated PD).
# -----------------------------------------------------------------------------
# Loaded when MODEL_NAME is set; PARALLEL_MODE is tp or dp from DP_MODE in the shell:
# DP_MODE=0 -> PARALLEL_MODE=tp -> tp_flags, prefill.tp, decode.tp (plus base_flags). Use this for all models except when using DP_MODE=1 below.
# DP_MODE=1 -> PARALLEL_MODE=dp -> dp_flags, prefill.dp, decode.dp (plus base_flags). Allowlist MORI_DP_MODE1_ALLOWED_MODELS in sglang_disagg_mori_io_ep.sh (same array in run_xPyD_models.slurm).
# Do not set --tp-size here; the launcher passes it (GENERIC_TP_SIZE=8 for DP_MODE=0).
#
# Each model contains:
# - base_flags: always applied (prefill + decode)
# - tp_flags / dp_flags: mode-level flags (dp_flags for DP_MODE=1 only on DeepSeek-V3 / DeepSeek-R1; other models use DP_MODE=0 only)
# - prefill/decode.<tp|dp>: role + mode specific flags
# - experimental_flags: optional extra CLI flags appended after role/mode flags on BOTH prefill and decode workers.
#   Use for try-outs that are independent of TP vs DP (PARALLEL_MODE); omit or "" when unused.
```
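A hypothetical sketch of how these layers might compose into one worker command line (the actual assembly happens inside `sglang_disagg_mori_io_ep.sh`, which is not shown in this diff); the flag values below are taken from the Llama-3.1-8B-Instruct decode/tp entry:

```shell
# Illustrative only: layer base_flags + mode flags + role/mode flags +
# experimental_flags into a single argument string for a decode worker
# running in tp mode (DP_MODE=0).
base_flags="--attention-backend aiter --watchdog-timeout 1000000 --mem-fraction-static 0.85 --disable-radix-cache"
tp_flags=""                          # mode-level flags (empty for dense models)
decode_tp="--max-running-requests 4096 --cuda-graph-bs $(seq 1 8)"
experimental_flags=""                # appended last, on both prefill and decode

args="$base_flags $tp_flags $decode_tp $experimental_flags"
echo $args   # unquoted on purpose: word splitting collapses seq's newlines
```

Note that `--tp-size` is deliberately absent: per the header comment, the launcher supplies it separately.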
```yaml
# --- Dense Llama / Qwen (lighter defaults; tune after first OOM on cluster) ---

Llama-3.1-8B-Instruct:
  base_flags: "--attention-backend aiter --watchdog-timeout 1000000 --mem-fraction-static 0.85 --disable-radix-cache"
  experimental_flags: ""
  tp_flags: ""
  dp_flags: ""
  prefill:
    tp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 4)"
    dp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 4)"
  decode:
    tp: "--max-running-requests 4096 --cuda-graph-bs $(seq 1 8)"
    dp: "--max-running-requests 4096 --cuda-graph-bs $(seq 1 8)"
```
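The `--cuda-graph-bs $(seq 1 4)` values rely on shell command substitution: `seq` prints the integers 1 through 4 on separate lines, and when the flag string is later expanded unquoted, word splitting turns them into an explicit list of CUDA-graph capture batch sizes:

```shell
# seq output is newline-separated; unquoted expansion word-splits it
# into "--cuda-graph-bs 1 2 3 4".
flag="--cuda-graph-bs $(seq 1 4)"
echo $flag   # -> --cuda-graph-bs 1 2 3 4
```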
```yaml
Qwen3-32B:
  base_flags: "--attention-backend aiter --watchdog-timeout 1000000 --mem-fraction-static 0.82 --disable-radix-cache"
  experimental_flags: ""
  tp_flags: ""
  dp_flags: ""
  prefill:
    tp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
    dp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
  decode:
    tp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"
    dp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"

amd-Llama-3.3-70B-Instruct-FP8-KV:
  base_flags: "--attention-backend aiter --watchdog-timeout 1000000 --mem-fraction-static 0.8 --disable-radix-cache"
  experimental_flags: ""
  tp_flags: ""
  dp_flags: ""
  prefill:
    tp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
    dp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
  decode:
    tp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"
    dp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"

Llama-3.1-405B-Instruct-FP8-KV:
  base_flags: "--attention-backend aiter --watchdog-timeout 1000000 --mem-fraction-static 0.78 --disable-radix-cache"
  experimental_flags: ""
  tp_flags: ""
  dp_flags: ""
  prefill:
    tp: "--max-running-requests 4 --cuda-graph-bs $(seq 1 2)"
    dp: "--max-running-requests 4 --cuda-graph-bs $(seq 1 2)"
  decode:
    tp: "--max-running-requests 2048 --cuda-graph-bs $(seq 1 6)"
    dp: "--max-running-requests 2048 --cuda-graph-bs $(seq 1 6)"
```
```yaml
# --- MoE ---

Mixtral-8x7B-Instruct-v0.1:
  base_flags: "--attention-backend aiter --watchdog-timeout 1000000 --mem-fraction-static 0.8 --disable-radix-cache"
  experimental_flags: ""
  tp_flags: ""
  dp_flags: ""
  prefill:
    tp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
    dp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
  decode:
    tp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"
    dp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"

DeepSeek-V3:
  base_flags: "--attention-backend aiter --watchdog-timeout 1000000 --mem-fraction-static 0.8 --disable-radix-cache"
  experimental_flags: ""
  tp_flags: ""
  dp_flags: "--moe-a2a-backend mori --enable-dp-attention --enable-dp-lm-head --moe-dense-tp-size 1 "
  prefill:
    tp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
    dp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
  decode:
    tp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"
    dp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"

DeepSeek-R1:
  base_flags: "--attention-backend aiter --watchdog-timeout 1000000 --mem-fraction-static 0.8 --disable-radix-cache"
  experimental_flags: ""
  tp_flags: ""
  dp_flags: "--moe-a2a-backend mori --enable-dp-attention --enable-dp-lm-head --moe-dense-tp-size 1 "
  prefill:
    tp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
    dp: "--max-running-requests 8 --cuda-graph-bs $(seq 1 3)"
  decode:
    tp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"
    dp: "--max-running-requests 8192 --cuda-graph-bs $(seq 1 8)"
```
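A hypothetical sketch (not the actual launcher code, which lives in `sglang_disagg_mori_io_ep.sh` / `run_xPyD_models.slurm`) of the DP_MODE-to-PARALLEL_MODE mapping and the DP_MODE=1 allowlist described in the catalog header comments:

```shell
# Assumed allowlist contents, per the catalog note that only
# DeepSeek-V3 / DeepSeek-R1 use DP_MODE=1.
MORI_DP_MODE1_ALLOWED_MODELS=("DeepSeek-V3" "DeepSeek-R1")
DP_MODE=1
MODEL_NAME="DeepSeek-V3"

# DP_MODE=1 -> dp (dp_flags + prefill.dp/decode.dp); otherwise tp.
if [[ "$DP_MODE" == "1" ]]; then PARALLEL_MODE=dp; else PARALLEL_MODE=tp; fi

# Membership check against the allowlist array.
allowed=0
for m in "${MORI_DP_MODE1_ALLOWED_MODELS[@]}"; do
  [[ "$m" == "$MODEL_NAME" ]] && allowed=1
done
echo "PARALLEL_MODE=$PARALLEL_MODE allowed=$allowed"
```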
New file (diff header `@@ -0,0 +1,59 @@`): multi-node environment defaults script.
```shell
#!/bin/bash
# MoRI / SGLang multi-node defaults aligned with SemiAnalysis InferenceX amd_utils.
# Upstream reference: https://github.com/SemiAnalysisAI/InferenceX/blob/main/benchmarks/multi_node/amd_utils/env.sh
#
# Expects MODEL_NAME to be set when sourced (for mxfp4-related toggles).
# IBDEVICES: prefer pre-set value; else reuse IB_DEVICES from the caller; else a safe default.

export PYTHONDONTWRITEBYTECODE="${PYTHONDONTWRITEBYTECODE:-1}"

export IB_DEVICES="${IB_DEVICES:-mlx5_0,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_7,mlx5_8,mlx5_9}"
```
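Every export in this script uses the `${VAR:-default}` parameter expansion, which keeps any caller-provided value and falls back only when the variable is unset or empty. A minimal standalone demo (device names illustrative):

```shell
# Case 1: variable unset -> the fallback after :- applies.
unset IB_DEVICES
export IB_DEVICES="${IB_DEVICES:-mlx5_0,mlx5_2}"
first="$IB_DEVICES"

# Case 2: variable pre-set by the caller -> its value wins.
IB_DEVICES="mlx5_9"
export IB_DEVICES="${IB_DEVICES:-mlx5_0,mlx5_2}"
echo "$first -> $IB_DEVICES"   # -> mlx5_0,mlx5_2 -> mlx5_9
```

This is what makes the script safe to source repeatedly: it never clobbers values the operator has already exported.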
```shell
# NCCL / RCCL IB HCAs (same list as disaggregation IB device list in typical setups)
export NCCL_IB_HCA="${NCCL_IB_HCA:-$IB_DEVICES}"

# Socket NIC: honor IFNAME / existing exports; else default route (InferenceX-style)
export GLOO_SOCKET_IFNAME=${GLOO_SOCKET_IFNAME:-$(ip route | grep '^default' | awk '{print $NF}' | head -n 1)}
export NCCL_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME:-${GLOO_SOCKET_IFNAME}}
```
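The `GLOO_SOCKET_IFNAME` detection can be exercised against a canned `ip route` line instead of the live routing table (the sample line below is an assumption). One caveat worth knowing: `awk '{print $NF}'` takes the last field, which is the device name only when the default-route line ends with the device; on systems where the line carries trailing fields such as `proto dhcp metric 100`, a `dev`-aware extraction would be needed instead.

```shell
# Sample default-route line; real output comes from `ip route`.
route_line="default via 10.0.0.1 dev eth0"

# Same pipeline as the script, minus the live `ip route` call.
ifname="$(echo "$route_line" | grep '^default' | awk '{print $NF}' | head -n 1)"
echo "$ifname"   # -> eth0
```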
```shell
export SGLANG_USE_AITER="${SGLANG_USE_AITER:-1}"
export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT="${SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT:-1200}"
export SGLANG_DISAGGREGATION_WAITING_TIMEOUT="${SGLANG_DISAGGREGATION_WAITING_TIMEOUT:-1200}"
```
```shell
# MoRI / MoE EP tuning: only apply when DP_MODE=1 (data-parallel MoRI path).
if [[ "${DP_MODE:-}" == "1" ]]; then
  #export MORI_SHMEM_MODE="${MORI_SHMEM_MODE:-ISOLATION}"

  # Symmetric heap for mori.shmem (MoE EP scales with world size / tokens). Mori default ~4G is often too small
  # for multi-GPU or multi-node; errors look like "Out of static heap memory" / HIP invalid argument after OOM.
  #export MORI_SHMEM_HEAP_SIZE="${MORI_SHMEM_HEAP_SIZE:-16G}"
  export SGLANG_MORI_FP8_DISP="${SGLANG_MORI_FP8_DISP:-True}"
  if [[ "${MODEL_NAME:-}" == *mxfp4* ]]; then
    export SGLANG_MORI_FP8_DISP=False
  fi

  export SGLANG_MORI_FP4_DISP="${SGLANG_MORI_FP4_DISP:-False}"
  export SGLANG_MORI_FP8_COMB="${SGLANG_MORI_FP8_COMB:-False}"

  #if [[ "${MODEL_NAME:-}" == *mxfp4* ]]; then
  #  export MORI_MAX_DISPATCH_TOKENS_PREFILL="${MORI_MAX_DISPATCH_TOKENS_PREFILL:-12288}"
  #else
  #  export MORI_MAX_DISPATCH_TOKENS_PREFILL="${MORI_MAX_DISPATCH_TOKENS_PREFILL:-16384}"
  #fi
  #export MORI_MAX_DISPATCH_TOKENS_DECODE="${MORI_MAX_DISPATCH_TOKENS_DECODE:-160}"

  ##export SGLANG_MORI_DISPATCH_INTER_KERNEL_SWITCH_THRESHOLD="${SGLANG_MORI_DISPATCH_INTER_KERNEL_SWITCH_THRESHOLD:-$((MORI_MAX_DISPATCH_TOKENS_DECODE * 2))}"

  #export MORI_EP_LAUNCH_CONFIG_MODE="${MORI_EP_LAUNCH_CONFIG_MODE:-AUTO}"
  #export MORI_IO_QP_MAX_SEND_WR="${MORI_IO_QP_MAX_SEND_WR:-16384}"
  #export MORI_IO_QP_MAX_CQE="${MORI_IO_QP_MAX_CQE:-32768}"
  #export MORI_IO_QP_MAX_SGE="${MORI_IO_QP_MAX_SGE:-4}"

  #export MORI_APP_LOG_LEVEL="${MORI_APP_LOG_LEVEL:-INFO}"
fi
```
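The `*mxfp4*` glob in `[[ ... ]]` matches the substring anywhere in `MODEL_NAME`, so any mxfp4 checkpoint forces FP8 dispatch off regardless of the surrounding name. A small demo (the model names below are illustrative, not from the catalog):

```shell
# Mirror of the toggle logic above, wrapped in a function for testing.
disp_for() {
  local SGLANG_MORI_FP8_DISP=True
  if [[ "$1" == *mxfp4* ]]; then
    SGLANG_MORI_FP8_DISP=False
  fi
  echo "$SGLANG_MORI_FP8_DISP"
}

disp_for "some-model-mxfp4"   # -> False
disp_for "DeepSeek-V3"        # -> True
```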
```shell
export SGLANG_ROUTER_STDOUT_LOGS="${SGLANG_ROUTER_STDOUT_LOGS:-0}"

if [[ -d /sgl-workspace/aiter ]]; then
  export PYTHONPATH="/sgl-workspace/aiter:${PYTHONPATH:-}"
fi
```
A review comment on the diff:
In a Dockerfile, `ARG` default values are taken literally (quotes are not stripped). With `ARG NIC_BACKEND="cx7"`, the value becomes `"cx7"`, so the later `case "${NIC_BACKEND}" in cx7)` will not match and will fall into the error branch. Use `ARG NIC_BACKEND=cx7` (no quotes), and document quoting in the `docker build --build-arg` example if needed.