
Fix SGLang engine startup diagnostics and stagger init#572

Open
alisonshao wants to merge 3 commits into main from fix/sglang-engine-startup-diagnostics-and-stagger

Conversation

@alisonshao
Collaborator

Summary

  • Added crash diagnostics to _wait_server_healthy: When an SGLang server process dies during startup, the error now includes the exit code, signal name (e.g. SIGKILL for OOM), elapsed time, and server URL. Previously it just said "Server process terminated unexpectedly." with no diagnostic info.
  • Staggered engine initialization: Rollout engines are now initialized sequentially instead of all at once. With RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1, simultaneous startup causes transient CUDA context allocations on shared GPUs, leading to OOM on machines with tight memory margins.

Root Cause

The 8-GPU megatron PPO test creates 4 SGLang engines (TP=2 each). All 4 were initialized simultaneously via ray.get([engine.init.remote(...) for ...]). Each engine subprocess can see all 8 GPUs during NCCL init, causing transient memory spikes. With mem_fraction_static=0.8, only ~2.47 GB remained available per H100 80GB — not enough margin for concurrent startup overhead.
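To see why staggering lowers the peak, here is a toy simulation (pure Python; the memory units and timings are invented for illustration, not measured from SGLang) of concurrent versus sequential engine startup, where each engine briefly needs extra "memory" while initializing and then settles to a smaller steady-state footprint:

```python
import threading
import time

class MemoryModel:
    """Toy model of shared GPU memory during engine startup: each engine
    transiently needs 2 units while initializing, then settles to 1 unit."""

    def __init__(self):
        self.current = 0
        self.peak = 0
        self.lock = threading.Lock()

    def init_engine(self):
        with self.lock:
            self.current += 2  # transient CUDA-context / NCCL-init spike
            self.peak = max(self.peak, self.current)
        time.sleep(0.05)  # startup takes a while
        with self.lock:
            self.current -= 1  # settle to steady-state footprint of 1 unit

n = 4  # the 8-GPU PPO test creates 4 engines

# Parallel startup: the spikes overlap, so peak usage approaches 2 * n units.
mem = MemoryModel()
threads = [threading.Thread(target=mem.init_engine) for _ in range(n)]
for t in threads:
    t.start()
for t in threads:
    t.join()
parallel_peak = mem.peak

# Sequential startup: at most one spike is in flight at a time,
# so peak usage stays at (already-started engines) + 2 = n + 1 units.
mem = MemoryModel()
for _ in range(n):
    mem.init_engine()
sequential_peak = mem.peak

print(f"parallel peak={parallel_peak}, sequential peak={sequential_peak}")
```

With tight margins (the ~2.47 GB described above), the difference between those two peaks is exactly the headroom that concurrent startup was exceeding.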

Changes

| File | Change |
| --- | --- |
| `miles/backends/sglang_utils/sglang_engine.py` | `_wait_server_healthy` now accepts a `process` param, reports exit code + signal on crash |
| `miles/ray/rollout.py` | Engine init changed from parallel `ray.get([...])` to a sequential loop |
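The shape of the control-flow change can be sketched with `concurrent.futures` standing in for Ray (the real code calls `ray.get` on actor method refs; `init_engine` and the thread pool here are illustrative stand-ins, not the repo's actual code):

```python
import threading
from concurrent.futures import ThreadPoolExecutor

start_log = []
log_lock = threading.Lock()

def init_engine(idx):
    """Stand-in for an SGLang engine's init: record that startup began."""
    with log_lock:
        start_log.append(idx)
    return f"engine-{idx} ready"

n_engines = 4

# Before: submit every init at once, then block on all of them,
# analogous to ray.get([engine.init.remote(...) for engine in engines]).
with ThreadPoolExecutor(max_workers=n_engines) as pool:
    parallel_results = list(pool.map(init_engine, range(n_engines)))

# After: a plain loop, so each engine finishes initializing
# before the next one starts.
start_log.clear()
sequential_results = [init_engine(i) for i in range(n_engines)]

print(sequential_results)
print(start_log)  # [0, 1, 2, 3] in strict order
```

The trade-off is longer total startup time (inits no longer overlap) in exchange for a bounded per-engine memory spike.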

Test plan

  • Run e2e/megatron/test_qwen3_4B_ppo.py on 8-GPU machine to verify engines start successfully
  • Verify crash diagnostics by checking logs when a server OOMs (should show exitcode=-9, killed by SIGKILL)
  • Verify external engine path (_init_external) still works with the updated _wait_server_healthy signature

…zation

1. _wait_server_healthy now reports exit code, signal name (e.g. SIGKILL for OOM),
   elapsed time, and URL when a server process crashes, instead of the bare
   "Server process terminated unexpectedly." message.

2. Rollout engines are now initialized sequentially instead of simultaneously.
   When all engines start at once with RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES=1,
   transient CUDA context allocations on shared GPUs can cause OOM on machines with
   tight memory margins (e.g. 2.47 GB available on H100 80GB with mem_fraction_static=0.8).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @alisonshao, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves the robustness and stability of SGLang engine initialization, particularly in resource-constrained environments. It tackles a critical OOM issue during concurrent engine startup by staggering their initialization and provides much-needed diagnostic clarity when an engine fails to start, allowing for quicker identification and resolution of underlying problems.

Highlights

  • Improved SGLang Engine Crash Diagnostics: Enhanced the _wait_server_healthy function to provide detailed error messages when an SGLang server process terminates unexpectedly during startup. The diagnostics now include the exit code, signal name (e.g., SIGKILL for OOM), elapsed time, and server URL, making debugging OOM issues much clearer.
  • Sequential SGLang Engine Initialization: Modified the engine initialization process to be sequential rather than parallel. This change addresses Out-Of-Memory (OOM) issues that occurred on shared GPUs during simultaneous startup, caused by transient CUDA context allocations.
Changelog
  • miles/backends/sglang_utils/sglang_engine.py
    • The _wait_server_healthy function now accepts a process object to retrieve detailed termination information.
    • Implemented a new internal helper function _raise_terminated to generate comprehensive error messages upon server process termination, including exit codes, signal names, and OOM hints.
    • Replaced generic exceptions with calls to _raise_terminated for better diagnostic output.
  • miles/ray/rollout.py
    • Changed the SGLang engine initialization from a parallel ray.get call to a sequential loop, ensuring engines are initialized one after another.
    • Added a comment explaining that this sequential initialization prevents transient GPU memory spikes and OOM issues during simultaneous CUDA context creation.
Activity
  • The author, alisonshao, created this pull request to address SGLang engine startup issues.
  • The pull request introduces improved diagnostics for server crashes and modifies the engine initialization strategy to prevent OOM errors.
  • The author has provided a detailed summary of the changes, the root cause of the problem, and a test plan.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces two valuable improvements. First, it significantly enhances the startup diagnostics for SGLang engines by providing detailed crash information, which will make debugging much easier. Second, it mitigates potential OOM issues by staggering engine initialization, moving from a parallel to a sequential startup process. The changes are well-reasoned and correctly implemented. I have one minor suggestion to improve code style.

Comment on lines +86 to +104

```python
def _raise_terminated(phase):
    elapsed = time.time() - start_time
    exitcode = getattr(process, "exitcode", None) if process is not None else None
    msg = (
        f"Server process terminated unexpectedly during {phase}. "
        f"exitcode={exitcode}, elapsed={elapsed:.1f}s, url={base_url}"
    )
    if exitcode is not None and exitcode < 0:
        import signal

        try:
            sig_name = signal.Signals(-exitcode).name
            msg += f" (killed by {sig_name})"
        except (ValueError, AttributeError):
            pass
        if exitcode == -9:
            msg += ". This is likely an OOM kill — check GPU memory availability."
    logger.error(msg)
    raise RuntimeError(msg)
```


Severity: medium

For better code style and readability, it's recommended to place imports at the top of the function scope where they are used, rather than inside conditional blocks. Moving import signal to the beginning of the _raise_terminated function aligns with this practice and makes the dependency clearer.

Suggested change

```python
# Current
def _raise_terminated(phase):
    elapsed = time.time() - start_time
    exitcode = getattr(process, "exitcode", None) if process is not None else None
    msg = (
        f"Server process terminated unexpectedly during {phase}. "
        f"exitcode={exitcode}, elapsed={elapsed:.1f}s, url={base_url}"
    )
    if exitcode is not None and exitcode < 0:
        import signal

        try:
            sig_name = signal.Signals(-exitcode).name
            msg += f" (killed by {sig_name})"
        except (ValueError, AttributeError):
            pass
        if exitcode == -9:
            msg += ". This is likely an OOM kill — check GPU memory availability."
    logger.error(msg)
    raise RuntimeError(msg)
```

```python
# Suggested
def _raise_terminated(phase):
    import signal

    elapsed = time.time() - start_time
    exitcode = getattr(process, "exitcode", None) if process is not None else None
    msg = (
        f"Server process terminated unexpectedly during {phase}. "
        f"exitcode={exitcode}, elapsed={elapsed:.1f}s, url={base_url}"
    )
    if exitcode is not None and exitcode < 0:
        try:
            sig_name = signal.Signals(-exitcode).name
            msg += f" (killed by {sig_name})"
        except (ValueError, AttributeError):
            pass
        if exitcode == -9:
            msg += ". This is likely an OOM kill — check GPU memory availability."
    logger.error(msg)
    raise RuntimeError(msg)
```
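For reference on why the decoding above works: `multiprocessing` reports a signal-killed child with a negative `exitcode`, so negating it yields the signal number. A self-contained sketch of that formatting logic (`fmt_exit` is a hypothetical helper mirroring the PR's message construction, not a function from the repo):

```python
import signal

def fmt_exit(exitcode):
    """Render a process exit code the way the new diagnostics do:
    a negative value means the process was killed by that signal."""
    msg = f"exitcode={exitcode}"
    if exitcode is not None and exitcode < 0:
        try:
            # signal.Signals maps a signal number to its symbolic name
            msg += f" (killed by {signal.Signals(-exitcode).name})"
        except ValueError:
            pass  # not a known signal number on this platform
        if exitcode == -9:
            msg += ", likely an OOM kill"
    return msg

print(fmt_exit(-9))  # on POSIX: exitcode=-9 (killed by SIGKILL), likely an OOM kill
print(fmt_exit(0))   # exitcode=0
```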

… daemon

Previously, NCCL_NVLS_ENABLE was set to "1" whenever NVLink was detected.
This force-enables NVLS (NVLink SHARP), which requires the IMEX daemon to
function. On H100 machines with NVLink but without IMEX, this causes:

  CUDA error 401 'the operation cannot be performed in the present state'
  Failed to bind NVLink SHARP (NVLS) Multicast memory

Now we only set NCCL_NVLS_ENABLE=0 when NVLink is absent. When NVLink is
present, we leave the variable unset so NCCL auto-detects NVLS support.
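The new policy can be sketched as follows (the function name and signature are illustrative, not the repo's actual helper; the env-var name `NCCL_NVLS_ENABLE` is the real NCCL knob described above):

```python
def configure_nvls(nvlink_present, env):
    """Sketch of the new behavior: force NVLS off only when NVLink is absent.

    When NVLink is present, NCCL_NVLS_ENABLE is deliberately left unset so
    NCCL auto-detects whether NVLS is usable; force-setting it to "1" on
    H100 machines without the IMEX daemon fails with CUDA error 401.
    """
    if not nvlink_present:
        env["NCCL_NVLS_ENABLE"] = "0"
    return env

print(configure_nvls(False, {}))  # {'NCCL_NVLS_ENABLE': '0'}
print(configure_nvls(True, {}))   # {} — left to NCCL auto-detection
```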
@alisonshao alisonshao requested a review from guapisolo as a code owner on February 6, 2026 at 20:18