diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 6a1f6af4a..c670b980e 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -840,6 +840,7 @@ def get_gateway_user_data(authorized_key: str) -> str: packages=[ "nginx", "python3.10-venv", + "python3-pip", # Add pip for sglang-router installation ], snap={"commands": [["install", "--classic", "certbot"]]}, runcmd=[ @@ -850,6 +851,8 @@ def get_gateway_user_data(authorized_key: str) -> str: "s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/", "/etc/nginx/nginx.conf", ], + # Install sglang-router system-wide. Can be conditionally installed in the future. + ["pip", "install", "sglang-router"], ["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands())], ], ssh_authorized_keys=[authorized_key], @@ -979,7 +982,8 @@ def get_dstack_gateway_wheel(build: str) -> str: r.raise_for_status() build = r.text.strip() logger.debug("Found the latest gateway build: %s", build) - return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + # return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + return "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.0-py3-none-any.whl" def get_dstack_gateway_commands() -> List[str]: diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 6a480b580..159ada65c 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -50,6 +50,7 @@ class GatewayConfiguration(CoreModel): default: Annotated[bool, Field(description="Make the gateway default")] = False backend: Annotated[BackendType, Field(description="The gateway backend")] region: Annotated[str, Field(description="The gateway region")] + router: Annotated[Optional[str], Field(description="The router type, e.g. `sglang`")] = None domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `example.com`") ] = None diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index b096fa80e..8c50db11e 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -4,9 +4,13 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m; {% if replicas %} upstream {{ domain }}.upstream { + {% if router == "sglang" %} + server 127.0.0.1:3000; # SGLang router on the gateway + {% else %} {% for replica in replicas %} server unix:{{ replica.socket }}; # replica {{ replica.id }} {% endfor %} + {% endif %} } {% else %} diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 new file mode 100644 index 000000000..a6d612d36 --- /dev/null +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 @@ -0,0 +1,23 @@ +{% for replica in replicas %} +# Worker {{ loop.index }} +upstream sglang_worker_{{ loop.index }}_upstream { + server unix:{{ replica.socket }}; +} + +server { + listen 127.0.0.1:{{ 10000 + loop.index }}; + access_log off; # disable access logs for this internal endpoint + + proxy_read_timeout 300s; + proxy_send_timeout 300s; + + location / { + proxy_pass http://sglang_worker_{{ loop.index }}_upstream; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header Connection ""; + proxy_set_header Upgrade $http_upgrade; + } +} +{% endfor %} diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index e1bfa4ff2..dd4f63f32 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -36,6 +36,7 @@ async def register_service( model=body.options.openai.model if body.options.openai is not None else None, ssh_private_key=body.ssh_private_key, repo=repo, + router=body.router, nginx=nginx, service_conn_pool=service_conn_pool, ) diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 8ab69b6af..117152a95 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -44,6 +44,7 @@ class RegisterServiceRequest(BaseModel): options: Options ssh_private_key: str rate_limits: tuple[RateLimit, ...] = () + router: Optional[str] = None class RegisterReplicaRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 2d3e755ac..edd3b2d21 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -1,13 +1,16 @@ import importlib.resources +import json +import shutil import subprocess import tempfile +import time +import urllib.parse from asyncio import Lock from pathlib import Path from typing import Optional import jinja2 from pydantic import BaseModel -from typing_extensions import Literal from dstack._internal.proxy.gateway.const import PROXY_PORT_ON_GATEWAY from dstack._internal.proxy.gateway.models import ACMESettings @@ -56,18 +59,17 @@ class LocationConfig(BaseModel): class ServiceConfig(SiteConfig): - type: Literal["service"] = "service" project_name: str auth: bool client_max_body_size: int access_log_path: Path limit_req_zones: list[LimitReqZoneConfig] locations: list[LocationConfig] - replicas: list[ReplicaConfig] + replicas: list[ReplicaConfig] = [] + router: Optional[str] = None class ModelEntrypointConfig(SiteConfig): - type: Literal["entrypoint"] = "entrypoint" project_name: str @@ -81,11 +83,14 @@ def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.debug("Registering %s domain %s", conf.type, conf.domain) conf_name = self.get_config_name(conf.domain) - async with self._lock: if conf.https: await run_async(self.run_certbot, conf.domain, acme) await run_async(self.write_conf, conf.render(), conf_name) + if isinstance(conf, ServiceConfig) and conf.router == "sglang": + replicas = len(conf.replicas) if conf.replicas else 1 + await run_async(self.write_sglang_workers_conf, conf) + await run_async(self.start_or_update_sglang_router, replicas) logger.info("Registered %s domain %s", conf.type, conf.domain) @@ -96,6 +101,14 @@ async def unregister(self, domain: str) -> None: return async with self._lock: await run_async(sudo_rm, conf_path) + workers_conf_path = self._conf_dir / f"sglang-workers.{domain}.conf" + if workers_conf_path.exists(): + await run_async(sudo_rm, workers_conf_path) + # Check if this was the last sglang service + if self._is_last_sglang_service(): + await run_async(self.stop_sglang_router) + else: + await run_async(self.update_sglang_router_workers, 0) await run_async(self.reload) logger.info("Unregistered domain %s", domain) @@ -168,6 +181,256 @@ def write_global_conf(self) -> None: conf = read_package_resource("00-log-format.conf") self.write_conf(conf, "00-log-format.conf") + def write_sglang_workers_conf(self, conf: ServiceConfig) -> None: + workers_config = generate_sglang_workers_config(conf) + workers_conf_name = f"sglang-workers.{conf.domain}.conf" + workers_conf_path = self._conf_dir / workers_conf_name + sudo_write(workers_conf_path, workers_config) + self.reload() + + @staticmethod + def get_sglang_router_workers() -> list[dict]: + try: + result = subprocess.run( + ["curl", "-s", "http://localhost:3000/workers"], capture_output=True, timeout=5 + ) + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + return response.get("workers", []) + return [] + except Exception as e: + logger.error(f"Error getting sglang router workers: {e}") + return [] + + @staticmethod + def add_sglang_router_worker(worker_url: str) -> bool: + try: + payload = {"url": worker_url, "worker_type": "regular"} + result = subprocess.run( + [ + "curl", + "-X", + "POST", + "http://localhost:3000/workers", + "-H", + "Content-Type: application/json", + "-d", + json.dumps(payload), + ], + capture_output=True, + timeout=5, + ) + + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + if response.get("status") == "accepted": + logger.info("Added worker %s to sglang router (queued)", worker_url) + return True + else: + logger.error("Failed to add worker %s: %s", worker_url, response) + return False + else: + logger.error("Failed to add worker %s: %s", worker_url, result.stderr.decode()) + return False + except Exception as e: + logger.error(f"Error adding worker {worker_url}: {e}") + return False + + @staticmethod + def remove_sglang_router_worker(worker_url: str) -> bool: + """Remove a single worker from sglang router""" + try: + # URL encode the worker URL for the DELETE request + encoded_url = urllib.parse.quote(worker_url, safe="") + + result = subprocess.run( + ["curl", "-X", "DELETE", f"http://localhost:3000/workers/{encoded_url}"], + capture_output=True, + timeout=5, + ) + + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + if response.get("status") == "accepted": + logger.info("Removed worker %s from sglang router (queued)", worker_url) + return True + else: + logger.error("Failed to remove worker %s: %s", worker_url, response) + return False + else: + logger.error("Failed to remove worker %s: %s", worker_url, result.stderr.decode()) + return False + except Exception as e: + logger.error(f"Error removing worker {worker_url}: {e}") + return False + + @staticmethod + def is_sglang_router_running() -> bool: + """Check if sglang router is running""" + try: + result = subprocess.run( + ["pgrep", "-f", "sglang::router"], capture_output=True, timeout=5 + ) + return result.returncode == 0 + except Exception as e: + logger.error(f"Error checking sglang router status: {e}") + return False + + @staticmethod + def start_sglang_router() -> None: + """Start sglang router without workers""" + try: + # Kill existing sglang router if running + if Nginx.is_sglang_router_running(): + logger.info("Stopping existing sglang-router...") + subprocess.run(["pkill", "-f", "sglang::router"], timeout=5) + time.sleep(1) + + # Start sglang router without workers + logger.info("Starting sglang-router...") + cmd = [ + "python3", + "-m", + "sglang_router.launch_router", + "--host", + "0.0.0.0", + "--port", + "3000", + "--log-level", + "debug", + "--log-dir", + "./router_logs", + "--request-timeout-secs", + "1800", + ] + subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + # Wait for router to start + time.sleep(2) + + # Verify router is running + if not Nginx.is_sglang_router_running(): + raise Exception("Failed to start sglang router") + + logger.info("Sglang router started successfully") + + except Exception as e: + logger.error(f"Failed to start sglang-router: {e}") + raise + + @staticmethod + def stop_sglang_router() -> None: + """Stop sglang router and clean up""" + try: + logger.info("Stopping sglang-router...") + subprocess.run(["pkill", "-f", "sglang::router"], timeout=5) + + # Clean up router logs + router_logs_path = Path("./router_logs") + if router_logs_path.exists(): + shutil.rmtree(router_logs_path) + + logger.info("Sglang router stopped and cleaned up") + + except Exception as e: + logger.error(f"Error stopping sglang router: {e}") + + @staticmethod + def start_or_update_sglang_router(replicas: int) -> None: + """Start or update sglang router with worker URLs based on replicas""" + try: + # Validate replicas count + if replicas < 0: + logger.error("Invalid replica count: %d", replicas) + return + + # Check if sglang router is already running + if not Nginx.is_sglang_router_running(): + logger.info("Sglang router not running, starting with %d replicas", replicas) + Nginx.start_sglang_router() + if replicas > 0: + Nginx.update_sglang_router_workers(replicas) + return + + # Router is running, just update workers + logger.info("Sglang router is running, updating to %d replicas", replicas) + Nginx.update_sglang_router_workers(replicas) + + except Exception as e: + logger.error(f"Error updating sglang router: {e}") + # Fallback: restart the router + logger.warning("Falling back to router restart") + try: + Nginx.stop_sglang_router() + Nginx.start_sglang_router() + if replicas > 0: + Nginx.update_sglang_router_workers(replicas) + except Exception as fallback_error: + logger.error(f"Fallback also failed: {fallback_error}") + + @staticmethod + def _is_last_sglang_service() -> bool: + """Check if this is the last sglang service (for cleanup)""" + try: + # Count sglang-worker config files + sglang_configs = list(Path("/etc/nginx/sites-enabled").glob("sglang-workers.*.conf")) + return len(sglang_configs) <= 1 # Current one being removed + except Exception as e: + logger.error(f"Error checking sglang services: {e}") + return True # Assume last service for safety + + @staticmethod + def update_sglang_router_workers(replicas: int) -> None: + """Update sglang router workers via HTTP API""" + try: + # Get current workers + current_workers = Nginx.get_sglang_router_workers() + current_worker_urls = {worker["url"] for worker in current_workers} + current_count = len(current_worker_urls) + + if current_count == replicas: + logger.info("Sglang router already has %d workers, no update needed", replicas) + return + + # Calculate target worker URLs + target_worker_urls = {f"http://127.0.0.1:{10000 + i}" for i in range(1, replicas + 1)} + + # Workers to add + workers_to_add = target_worker_urls - current_worker_urls + # Workers to remove + workers_to_remove = current_worker_urls - target_worker_urls + + logger.info( + "Sglang router update: adding %d workers, removing %d workers", + len(workers_to_add), + len(workers_to_remove), + ) + + # Add new workers + for worker_url in sorted(workers_to_add): + success = Nginx.add_sglang_router_worker(worker_url) + if not success: + logger.warning("Failed to add worker %s, continuing with others", worker_url) + + # Remove old workers + for worker_url in sorted(workers_to_remove): + success = Nginx.remove_sglang_router_worker(worker_url) + if not success: + logger.warning( + "Failed to remove worker %s, continuing with others", worker_url + ) + + except Exception as e: + logger.error(f"Error updating sglang router workers: {e}") + + +def generate_sglang_workers_config(conf: ServiceConfig) -> str: + template = read_package_resource("sglang_workers.jinja2") + return jinja2.Template(template).render( + replicas=conf.replicas, + proxy_port=PROXY_PORT_ON_GATEWAY, + ) + def read_package_resource(file: str) -> str: return ( diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 3ea412d79..ce7cbca2e 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -44,6 +44,7 @@ async def register_service( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, + router: Optional[str] = None, ) -> None: service = models.Service( project_name=project_name, @@ -54,6 +55,7 @@ async def register_service( auth=auth, client_max_body_size=client_max_body_size, replicas=(), + router=router, ) async with lock: @@ -335,6 +337,7 @@ async def get_nginx_service_config( limit_req_zones=limit_req_zones, locations=locations, replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs + router=service.router, ) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index 5cb5471d8..4e7046167 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -57,6 +57,7 @@ class Service(ImmutableModel): client_max_body_size: int # only enforced on gateways strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] + router: Optional[str] = None @property def domain_safe(self) -> str: diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index f8c090079..33bfee3f5 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -45,6 +45,7 @@ async def register_service( options: dict, rate_limits: list[RateLimit], ssh_private_key: str, + router: Optional[str] = None, ): if "openai" in options: entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}" @@ -59,6 +60,7 @@ async def register_service( "options": options, "rate_limits": [limit.dict() for limit in rate_limits], "ssh_private_key": ssh_private_key, + "router": router, } resp = await self._client.post( self._url(f"/api/registry/{project}/services/register"), json=payload diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index a8089a93a..05c1fa909 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -82,6 +82,7 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) + router = gateway_configuration.router service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: @@ -119,6 +120,7 @@ async def _register_service_in_gateway( options=service_spec.options, rate_limits=run_spec.configuration.rate_limits, ssh_private_key=run_model.project.ssh_private_key, + router=router, ) logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url) except SSHError: