Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 101 additions & 2 deletions rock/sdk/model/server/api/proxy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,109 @@
from typing import Any

from fastapi import APIRouter, Request
import httpx
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse

from rock.logger import init_logger
from rock.sdk.model.server.config import ModelServiceConfig
from rock.utils import retry_async

logger = init_logger(__name__)

proxy_router = APIRouter()


# Global HTTP client with a persistent connection pool
http_client = httpx.AsyncClient()


@retry_async(
    max_attempts=6,
    delay_seconds=2.0,
    backoff=2.0,  # Exponential backoff (2s, 4s, 8s, 16s, 32s).
    jitter=True,  # Adds randomness to prevent "thundering herd" effect on the backend.
    exceptions=(httpx.TimeoutException, httpx.ConnectError, httpx.HTTPStatusError)
)
async def perform_llm_request(url: str, body: dict, headers: dict, config: ModelServiceConfig):
    """
    Forward a POST request to the LLM backend through the shared client.

    A retry is triggered ONLY when the response status code appears in the
    configured retryable whitelist; every other status is returned as-is.
    """
    response = await http_client.post(url, json=body, headers=headers, timeout=config.request_timeout)
    status_code = response.status_code

    # Fast path: anything outside the whitelist is handed back untouched,
    # even if it is an error status (the caller decides what to do with it).
    if status_code not in config.retryable_status_codes:
        return response

    logger.warning(f"Retryable error detected: {status_code}. Triggering retry for {url}...")
    # raise_for_status() raises httpx.HTTPStatusError, which the decorator
    # catches and converts into another attempt.
    response.raise_for_status()
    return response


def get_base_url(model_name: str, config: ModelServiceConfig) -> str:
    """
    Pick the backend base URL for the given model name.

    Routing falls back to the 'default' rule when no model-specific rule
    matches; a missing fallback is reported as a client error (HTTP 400).
    """
    if not model_name:
        raise HTTPException(status_code=400, detail="Model name is required for routing.")

    routing = config.proxy_rules
    target = routing.get(model_name) or routing.get("default")
    if target:
        # Normalize: strip trailing slashes so callers can safely append paths.
        return target.rstrip("/")

    raise HTTPException(status_code=400, detail=f"Model '{model_name}' is not configured and no 'default' rule found.")


@proxy_router.post("/v1/chat/completions")
async def chat_completions(body: dict[str, Any], request: Request):
    """
    OpenAI-compatible chat completions proxy endpoint.
    Handles routing, header transparent forwarding, and automatic retries.

    Note: the stray `raise NotImplementedError` stub that preceded this
    docstring made the whole implementation unreachable; it has been removed.
    """
    config = request.app.state.model_service_config

    # Step 1: Model Routing
    model_name = body.get("model", "")
    base_url = get_base_url(model_name, config)
    target_url = f"{base_url}/chat/completions"
    logger.info(f"Routing model '{model_name}' to URL: {target_url}")

    # Step 2: Header Cleaning
    # Preserve 'Authorization' for authentication while removing hop-by-hop transport headers.
    # 'content-length'/'content-type' are dropped because httpx recomputes them for the JSON body.
    forwarded_headers = {}
    for key, value in request.headers.items():
        if key.lower() in ["host", "content-length", "content-type", "transfer-encoding"]:
            continue
        forwarded_headers[key] = value

    # Step 3: Strategy Enforcement
    # Force non-streaming mode for the MVP phase to ensure stability.
    body["stream"] = False

    try:
        # Step 4: Execute Request with Retry Logic
        response = await perform_llm_request(target_url, body, forwarded_headers, config)
        return JSONResponse(status_code=response.status_code, content=response.json())

    except httpx.HTTPStatusError as e:
        # Forward the raw backend error message to the client.
        # This allows the Agent-side logic to detect keywords like 'context length exceeded'
        # or 'content violation' and raise appropriate exceptions.
        error_text = e.response.text if e.response else "No error details"
        status_code = e.response.status_code if e.response else 502
        logger.error(f"Final failure after retries. Status: {status_code}, Response: {error_text}")
        return JSONResponse(
            status_code=status_code,
            content={
                "error": {
                    "message": f"LLM backend error: {error_text}",
                    "type": "proxy_retry_failed",
                    "code": status_code
                }
            }
        )
    except Exception as e:
        logger.error(f"Unexpected proxy error: {str(e)}")
        # Raise standard 500 for non-HTTP related coding or system errors;
        # chain with `from e` so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
46 changes: 46 additions & 0 deletions rock/sdk/model/server/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from pathlib import Path

import yaml
from pydantic import BaseModel, Field

from rock import env_vars

"""Configuration for LLM Service."""
Expand All @@ -20,3 +25,44 @@
RESPONSE_START_MARKER = "LLM_RESPONSE_START"
RESPONSE_END_MARKER = "LLM_RESPONSE_END"
SESSION_END_MARKER = "SESSION_END"


class ModelServiceConfig(BaseModel):
    """Runtime settings for the LLM proxy service."""

    # Maps a model name to its backend base URL; the "default" entry is the
    # fallback route used when no model-specific rule matches.
    proxy_rules: dict[str, str] = Field(
        default_factory=lambda: {
            "gpt-3.5-turbo": "https://api.openai.com/v1",
            "default": "https://api-inference.modelscope.cn/v1"
        }
    )

    # Only these codes will trigger a retry.
    # Codes not in this list (e.g., 400, 401, 403, or certain 5xx/6xx) will fail immediately.
    retryable_status_codes: list[int] = Field(default_factory=lambda: [429, 500])

    # Per-request timeout (seconds) passed to the HTTP client.
    request_timeout: int = 120

    @classmethod
    def from_file(cls, config_path: str | None = None):
        """
        Build a config instance, optionally reading overrides from YAML.

        Args:
            config_path: Path to the YAML file. If None, returns default config.

        Raises:
            FileNotFoundError: when a path is given but no file exists there.
        """
        if not config_path:
            return cls()

        config_file = Path(config_path)
        if not config_file.exists():
            raise FileNotFoundError(f"Config file {config_file} not found")

        with open(config_file, encoding="utf-8") as f:
            loaded = yaml.safe_load(f)

        # An empty YAML document parses to None -> fall back to defaults.
        return cls() if loaded is None else cls(**loaded)
22 changes: 19 additions & 3 deletions rock/sdk/model/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rock.logger import init_logger
from rock.sdk.model.server.api.local import init_local_api, local_router
from rock.sdk.model.server.api.proxy import proxy_router
from rock.sdk.model.server.config import SERVICE_HOST, SERVICE_PORT
from rock.sdk.model.server.config import SERVICE_HOST, SERVICE_PORT, ModelServiceConfig

# Configure logging
logger = init_logger(__name__)
Expand All @@ -20,6 +20,17 @@
async def lifespan(app: FastAPI):
    """Application lifespan context manager.

    On startup, loads ModelServiceConfig from app.state.config_path when one
    was provided (failing fast if loading breaks), otherwise installs default
    settings. Yields control to the running app, then logs shutdown.
    """
    logger.info("LLM Service started")
    config_path = getattr(app.state, "config_path", None)
    if config_path:
        try:
            app.state.model_service_config = ModelServiceConfig.from_file(config_path)
            logger.info(f"Model Service Config loaded from: {config_path}")
        except Exception as e:
            logger.error(f"Failed to load config from {config_path}: {e}")
            # Bare re-raise preserves the original exception and traceback
            # (`raise e` re-raised the same object through an extra frame).
            raise
    else:
        app.state.model_service_config = ModelServiceConfig()
        logger.info("No config file specified. Using default config settings.")
    yield
    logger.info("LLM Service shutting down")

Expand Down Expand Up @@ -49,8 +60,9 @@ async def global_exception_handler(request, exc):
)


def main(model_servie_type: str):
def main(model_servie_type: str, config_file: str | None):
logger.info(f"Starting LLM Service on {SERVICE_HOST}:{SERVICE_PORT}, type: {model_servie_type}")
app.state.config_path = config_file
if model_servie_type == "local":
asyncio.run(init_local_api())
app.include_router(local_router, prefix="", tags=["local"])
Expand All @@ -64,7 +76,11 @@ def main(model_servie_type: str):
parser.add_argument(
"--type", type=str, choices=["local", "proxy"], default="local", help="Type of LLM service (local/proxy)"
)
parser.add_argument(
"--config-file", type=str, default=None, help="Path to the configuration YAML file. If not set, default values will be used."
)
args = parser.parse_args()
model_servie_type = args.type
config_file = args.config_file

main(model_servie_type)
main(model_servie_type, config_file)
Loading