From 8175c9fdc72bee9702fb326e3274ce3dbef8231e Mon Sep 17 00:00:00 2001 From: fengling Date: Thu, 22 Jan 2026 03:27:21 +0000 Subject: [PATCH 1/5] feat: add model service proxy type --- rock/env_vars.py | 2 + rock/sdk/model/server/api/proxy.py | 114 ++++++++++++++++- rock/sdk/model/server/config.py | 36 ++++++ rock/sdk/model/server/main.py | 8 +- rock/sdk/model/server/model_proxy_config.yml | 24 ++++ tests/unit/sdk/model/test_proxy.py | 128 +++++++++++++++++++ 6 files changed, 309 insertions(+), 3 deletions(-) create mode 100644 rock/sdk/model/server/model_proxy_config.yml create mode 100644 tests/unit/sdk/model/test_proxy.py diff --git a/rock/env_vars.py b/rock/env_vars.py index 5062a5b28..5212cf621 100644 --- a/rock/env_vars.py +++ b/rock/env_vars.py @@ -41,6 +41,7 @@ # Model Service Config ROCK_MODEL_SERVICE_DATA_DIR: str + ROCK_MODEL_PROXY_CONFIG: str # Agentic ROCK_AGENT_PRE_INIT_BASH_CMD_LIST: list[str] = [] @@ -87,6 +88,7 @@ "ROCK_CLI_DEFAULT_CONFIG_PATH", Path.home() / ".rock" / "config.ini" ), "ROCK_MODEL_SERVICE_DATA_DIR": lambda: os.getenv("ROCK_MODEL_SERVICE_DATA_DIR", "/data/logs"), + "ROCK_MODEL_PROXY_CONFIG": lambda: os.getenv("ROCK_MODEL_PROXY_CONFIG", str(Path(__file__) / "model_proxy_config.yml")), "ROCK_AGENT_PYTHON_INSTALL_CMD": lambda: os.getenv( "ROCK_AGENT_PYTHON_INSTALL_CMD", "[ -f cpython31114.tar.gz ] && rm cpython31114.tar.gz; [ -d python ] && rm -rf python; wget -q -O cpython31114.tar.gz https://github.com/astral-sh/python-build-standalone/releases/download/20251120/cpython-3.11.14+20251120-x86_64-unknown-linux-gnu-install_only.tar.gz && tar -xzf cpython31114.tar.gz", diff --git a/rock/sdk/model/server/api/proxy.py b/rock/sdk/model/server/api/proxy.py index 9be4b7218..9719e5898 100644 --- a/rock/sdk/model/server/api/proxy.py +++ b/rock/sdk/model/server/api/proxy.py @@ -1,10 +1,120 @@ from typing import Any -from fastapi import APIRouter, Request +import httpx +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import JSONResponse + +from rock.logger import init_logger +from rock.sdk.model.server.config import ModelProxyConfig +from rock.utils import retry_async + +logger = init_logger(__name__) proxy_router = APIRouter() +# Global HTTP client with a persistent connection pool +http_client = httpx.AsyncClient() + + +@retry_async( + max_attempts=6, + delay_seconds=2.0, + backoff=2.0, # Exponential backoff (2s, 4s, 8s, 16s, 32s). + jitter=True, # Adds randomness to prevent "thundering herd" effect on the backend. + exceptions=(httpx.TimeoutException, httpx.ConnectError, httpx.HTTPStatusError) +) +async def perform_llm_request(url: str, body: dict, headers: dict, config: ModelProxyConfig): + """ + Forwards the request and triggers retry ONLY if the status code + is in the explicit retryable whitelist. + """ + response = await http_client.post(url, json=body, headers=headers, timeout=config.request_timeout) + status_code = response.status_code + + # Check against the explicit whitelist + if status_code in config.retryable_status_codes: + logger.warning(f"Retryable error detected: {status_code}. Triggering retry for {url}...") + response.raise_for_status() + + return response + + +def get_target_url(model_name: str, config: ModelProxyConfig) -> str: + """ + Selects the target backend URL based on keyword matching in the model name. + Example: 'qwen-max' contains 'qwen', thus matching the Whale provider URL. + """ + rules = config.proxy_rules + if not model_name: + if "default" in rules: + return rules["default"] + raise HTTPException(status_code=400, detail="Model name is missing and no default rule found.") + + model_name_lower = model_name.lower() + + # Iterate through PROXY_RULES for keyword matching + for key, url in rules.items(): + if key != "default" and key in model_name_lower: + return url + + if "default" in rules: + return rules["default"] + + raise HTTPException( + status_code=400, + detail=f"Model '{model_name}' not configured in proxy rules." + ) + + @proxy_router.post("/v1/chat/completions") async def chat_completions(body: dict[str, Any], request: Request): - raise NotImplementedError("Proxy chat completions not implemented yet") + """ + OpenAI-compatible chat completions proxy endpoint. + Handles routing, header transparent forwarding, and automatic retries. + """ + config = request.app.state.model_proxy_config + + # Step 1: Model Routing + model_name = body.get("model", "") + target_url = get_target_url(model_name, config) + logger.info(f"Routing model '{model_name}' to URL: {target_url}") + + # Step 2: Header Cleaning + # Preserve 'Authorization' for authentication while removing hop-by-hop transport headers. + forwarded_headers = {} + for key, value in request.headers.items(): + if key.lower() in ["host", "content-length", "content-type", "transfer-encoding"]: + continue + forwarded_headers[key] = value + + # Step 3: Strategy Enforcement + # Force non-streaming mode for the MVP phase to ensure stability. + body["stream"] = False + + try: + # Step 4: Execute Request with Retry Logic + response = await perform_llm_request(target_url, body, forwarded_headers) + return JSONResponse(status_code=response.status_code, content=response.json()) + + except httpx.HTTPStatusError as e: + # Forward the raw backend error message to the client. + # This allows the Agent-side logic to detect keywords like 'context length exceeded' + # or 'content violation' and raise appropriate exceptions. + error_text = e.response.text if e.response else "No error details" + status_code = e.response.status_code if e.response else 502 + logger.error(f"Final failure after retries. Status: {status_code}, Response: {error_text}") + return JSONResponse( + status_code=status_code, + content={ + "error": { + "message": f"LLM backend error: {error_text}", + "type": "proxy_retry_failed", + "code": status_code + } + } + ) + except Exception as e: + logger.error(f"Unexpected proxy error: {str(e)}") + # Raise standard 500 for non-HTTP related coding or system errors + raise HTTPException(status_code=500, detail=str(e)) diff --git a/rock/sdk/model/server/config.py b/rock/sdk/model/server/config.py index a213179db..255ebf27d 100644 --- a/rock/sdk/model/server/config.py +++ b/rock/sdk/model/server/config.py @@ -1,3 +1,8 @@ +from pathlib import Path + +import yaml +from pydantic import BaseModel, Field + from rock import env_vars """Configuration for LLM Service.""" @@ -20,3 +25,34 @@ RESPONSE_START_MARKER = "LLM_RESPONSE_START" RESPONSE_END_MARKER = "LLM_RESPONSE_END" SESSION_END_MARKER = "SESSION_END" + + +class ModelProxyConfig(BaseModel): + proxy_rules: dict[str, str] = Field(default_factory=dict) + + # Only these codes will trigger a retry. + # Codes not in this list (e.g., 400, 401, 403, or certain 5xx/6xx) will fail immediately. + retryable_status_codes: list[int] = Field(default_factory=lambda: [429, 499]) + + request_timeout: int = 120 + + @classmethod + def from_env(cls, config_path: str | None = None): + if not config_path: + config_path = env_vars.ROCK_MODEL_PROXY_CONFIG + + if not config_path: + return cls() + + config_file = Path(config_path) + + if not config_file.exists(): + raise FileNotFoundError(f"Config file {config_file} not found") + + with open(config_file, encoding="utf-8") as f: + config_data = yaml.safe_load(f) + + if config_data is None: + return cls() + + return cls(**config_data) diff --git a/rock/sdk/model/server/main.py b/rock/sdk/model/server/main.py index dde7cf860..7c38a414e 100644 --- a/rock/sdk/model/server/main.py +++ b/rock/sdk/model/server/main.py @@ -7,10 +7,11 @@ from fastapi import FastAPI, status from fastapi.responses import JSONResponse +from rock import env_vars from rock.logger import init_logger from rock.sdk.model.server.api.local import init_local_api, local_router from rock.sdk.model.server.api.proxy import proxy_router -from rock.sdk.model.server.config import SERVICE_HOST, SERVICE_PORT +from rock.sdk.model.server.config import SERVICE_HOST, SERVICE_PORT, ModelProxyConfig # Configure logging logger = init_logger(__name__) @@ -20,6 +21,11 @@ async def lifespan(app: FastAPI): """Application lifespan context manager.""" logger.info("LLM Service started") + if model_servie_type == "proxy": + config_file_path = env_vars.ROCK_MODEL_PROXY_CONFIG if env_vars.ROCK_MODEL_PROXY_CONFIG else "model_proxy_config.yml" + model_proxy_config = ModelProxyConfig.from_env(config_file_path) + app.state.model_proxy_config = model_proxy_config + logger.info(f"Model Proxy Config loaded from {config_file_path}") yield logger.info("LLM Service shutting down") diff --git a/rock/sdk/model/server/model_proxy_config.yml b/rock/sdk/model/server/model_proxy_config.yml new file mode 100644 index 000000000..376516d05 --- /dev/null +++ b/rock/sdk/model/server/model_proxy_config.yml @@ -0,0 +1,24 @@ +proxy_rules: + qwen: "https://offline-whale-wave.alibaba-inc.com/api/v2/services/aigc/text-generation/v1/chat/completions" + gpt: "https://api.openai.com/v1/chat/completions" +retryable_status_codes: + - 409, # ENGINE_CONCURRENCY_CONFLICT + - 429, # THROTTLING + - 499, # CLIENT_DISCONNECT + - 518, # ROUTE_CONFIG_ERROR + - 519, # NETTY_CATCH_ERROR + - 520, # PYRTP_CONNECT_ERROR + - 523, # READ_TIME_OUT + - 524, # RETRY_EXHAUSTED + - 527, # CONNECTION_RESET_EXCEPTION + - 528, # ENGINE_RESPONSE_TOO_LARGE + - 600, # BALANCING_WORKER_EXCEPTION + - 602, # MALLOC_EXCEPTION + - 603, # ENGINE_TIMEOUT_EXCEPTION + - 8000, # ENGINE_ABNORMAL_DISCONNECT_EXCEPTION + - 8200, # GET_HOST_FAILED + - 8202, # CONNECT_FAILED + - 8211, # DECODE_MALLOC_FAILED + - 8213, # WAIT_TO_RUN_TIMEOUT + - 8303 # CACHE_STORE_LOAD_SEND_REQUEST_FAILED +timeout: 180 \ No newline at end of file diff --git a/tests/unit/sdk/model/test_proxy.py b/tests/unit/sdk/model/test_proxy.py new file mode 100644 index 000000000..2102d19a4 --- /dev/null +++ b/tests/unit/sdk/model/test_proxy.py @@ -0,0 +1,128 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from fastapi import FastAPI, Request +from httpx import ASGITransport, AsyncClient, HTTPStatusError, Request, Response + +from rock.sdk.model.server.api.proxy import perform_llm_request, proxy_router +from rock.sdk.model.server.config import ModelProxyConfig + +# Initialize a temporary FastAPI application for testing the router +test_app = FastAPI() +test_app.include_router(proxy_router) + +mock_config = ModelProxyConfig( + proxy_rules={ + "qwen": "http://whale.url", + "default": "http://default.url" + }, + retryable_status_codes=[429, 499], + request_timeout=60 +) +test_app.state.model_proxy_config = mock_config + +@pytest.mark.asyncio +async def test_chat_completions_routing(): + """ + Test the high-level routing logic. + """ + patch_path = 'rock.sdk.model.server.api.proxy.perform_llm_request' + + with patch(patch_path, new_callable=AsyncMock) as mock_request: + mock_resp = MagicMock(spec=Response) + mock_resp.status_code = 200 + mock_resp.json.return_value = {"id": "chat-123", "choices": []} + mock_request.return_value = mock_resp + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + payload = { + "model": "Qwen2.5-72B", + "messages": [{"role": "user", "content": "hello"}] + } + response = await ac.post("/v1/chat/completions", json=payload) + + assert response.status_code == 200 + call_args = mock_request.call_args[0] + assert call_args[0] == "http://whale.url" + assert mock_request.called + + +@pytest.mark.asyncio +async def test_perform_llm_request_retry_on_whitelist(): + """ + Test that the proxy retries when receiving a whitelisted error code. + """ + client_post_path = 'rock.sdk.model.server.api.proxy.http_client.post' + + # Patch asyncio.sleep inside the retry module to avoid actual waiting + with patch(client_post_path, new_callable=AsyncMock) as mock_post, \ + patch('rock.utils.retry.asyncio.sleep', return_value=None): + + # 1. Setup Failed Response (429) + resp_429 = MagicMock(spec=Response) + resp_429.status_code = 429 + error_429 = HTTPStatusError( + "Rate Limited", + request=MagicMock(spec=Request), + response=resp_429 + ) + + # 2. Setup Success Response (200) + resp_200 = MagicMock(spec=Response) + resp_200.status_code = 200 + resp_200.json.return_value = {"ok": True} + + # Sequence: Fail with 429, then Succeed with 200 + mock_post.side_effect = [error_429, resp_200] + + result = await perform_llm_request("http://fake.url", {}, {}, mock_config) + + assert result.status_code == 200 + assert mock_post.call_count == 2 + + +@pytest.mark.asyncio +async def test_perform_llm_request_no_retry_on_non_whitelist(): + """ + Test that the proxy DOES NOT retry for non-retryable codes (e.g., 401). + It should return the error response immediately. + """ + client_post_path = 'rock.sdk.model.server.api.proxy.http_client.post' + + with patch(client_post_path, new_callable=AsyncMock) as mock_post: + # Mock 401 Unauthorized (NOT in the retry whitelist) + resp_401 = MagicMock(spec=Response) + resp_401.status_code = 401 + resp_401.json.return_value = {"error": "Invalid API Key"} + + # The function should return this response directly + mock_post.return_value = resp_401 + + result = await perform_llm_request("http://fake.url", {}, {}, mock_config) + + assert result.status_code == 401 + # Call count must be 1, meaning no retries were attempted + assert mock_post.call_count == 1 + + +@pytest.mark.asyncio +async def test_perform_llm_request_network_timeout_retry(): + """ + Test that network-level exceptions (like Timeout) also trigger retries. + """ + client_post_path = 'rock.sdk.model.server.api.proxy.http_client.post' + + with patch(client_post_path, new_callable=AsyncMock) as mock_post, \ + patch('rock.utils.retry.asyncio.sleep', return_value=None): + + resp_200 = MagicMock(spec=Response) + resp_200.status_code = 200 + + mock_post.side_effect = [httpx.TimeoutException("Network Timeout"), resp_200] + + result = await perform_llm_request("http://fake.url", {}, {}, mock_config) + + assert result.status_code == 200 + assert mock_post.call_count == 2 From 550f91384e3309171d324d42655ea1286bfe2261 Mon Sep 17 00:00:00 2001 From: fengling Date: Thu, 22 Jan 2026 08:41:47 +0000 Subject: [PATCH 2/5] feat: add config-file arg for model service and add more ut --- rock/env_vars.py | 2 - rock/sdk/model/server/api/proxy.py | 31 ++---- rock/sdk/model/server/config.py | 25 +++-- rock/sdk/model/server/main.py | 25 +++-- rock/sdk/model/server/model_proxy_config.yml | 24 ----- tests/unit/sdk/model/test_proxy.py | 99 ++++++++++++++++---- 6 files changed, 126 insertions(+), 80 deletions(-) delete mode 100644 rock/sdk/model/server/model_proxy_config.yml diff --git a/rock/env_vars.py b/rock/env_vars.py index 5212cf621..5062a5b28 100644 --- a/rock/env_vars.py +++ b/rock/env_vars.py @@ -41,7 +41,6 @@ # Model Service Config ROCK_MODEL_SERVICE_DATA_DIR: str - ROCK_MODEL_PROXY_CONFIG: str # Agentic ROCK_AGENT_PRE_INIT_BASH_CMD_LIST: list[str] = [] @@ -88,7 +87,6 @@ "ROCK_CLI_DEFAULT_CONFIG_PATH", Path.home() / ".rock" / "config.ini" ), "ROCK_MODEL_SERVICE_DATA_DIR": lambda: os.getenv("ROCK_MODEL_SERVICE_DATA_DIR", "/data/logs"), - "ROCK_MODEL_PROXY_CONFIG": lambda: os.getenv("ROCK_MODEL_PROXY_CONFIG", str(Path(__file__) / "model_proxy_config.yml")), "ROCK_AGENT_PYTHON_INSTALL_CMD": lambda: os.getenv( "ROCK_AGENT_PYTHON_INSTALL_CMD", "[ -f cpython31114.tar.gz ] && rm cpython31114.tar.gz; [ -d python ] && rm -rf python; wget -q -O cpython31114.tar.gz https://github.com/astral-sh/python-build-standalone/releases/download/20251120/cpython-3.11.14+20251120-x86_64-unknown-linux-gnu-install_only.tar.gz && tar -xzf cpython31114.tar.gz", diff --git a/rock/sdk/model/server/api/proxy.py b/rock/sdk/model/server/api/proxy.py index 9719e5898..f9c2ab778 100644 --- a/rock/sdk/model/server/api/proxy.py +++ b/rock/sdk/model/server/api/proxy.py @@ -5,7 +5,7 @@ from fastapi.responses import JSONResponse from rock.logger import init_logger -from rock.sdk.model.server.config import ModelProxyConfig +from rock.sdk.model.server.config import ModelServiceConfig from rock.utils import retry_async logger = init_logger(__name__) @@ -24,7 +24,7 @@ jitter=True, # Adds randomness to prevent "thundering herd" effect on the backend. exceptions=(httpx.TimeoutException, httpx.ConnectError, httpx.HTTPStatusError) ) -async def perform_llm_request(url: str, body: dict, headers: dict, config: ModelProxyConfig): +async def perform_llm_request(url: str, body: dict, headers: dict, config: ModelServiceConfig): """ Forwards the request and triggers retry ONLY if the status code is in the explicit retryable whitelist. @@ -40,31 +40,18 @@ async def perform_llm_request(url: str, body: dict, headers: dict, config: Model return response -def get_target_url(model_name: str, config: ModelProxyConfig) -> str: +def get_target_url(model_name: str, config: ModelServiceConfig) -> str: """ - Selects the target backend URL based on keyword matching in the model name. - Example: 'qwen-max' contains 'qwen', thus matching the Whale provider URL. + Selects the target backend URL based on model name matching. """ rules = config.proxy_rules if not model_name: - if "default" in rules: - return rules["default"] - raise HTTPException(status_code=400, detail="Model name is missing and no default rule found.") + raise HTTPException(status_code=400, detail="Model name is required for routing.") - model_name_lower = model_name.lower() + if model_name in rules: + return rules[model_name] - # Iterate through PROXY_RULES for keyword matching - for key, url in rules.items(): - if key != "default" and key in model_name_lower: - return url - - if "default" in rules: - return rules["default"] - - raise HTTPException( - status_code=400, - detail=f"Model '{model_name}' not configured in proxy rules." - ) + raise HTTPException(status_code=400, detail=f"Model '{model_name}' not configured in proxy rules.") @proxy_router.post("/v1/chat/completions") @@ -73,7 +60,7 @@ async def chat_completions(body: dict[str, Any], request: Request): OpenAI-compatible chat completions proxy endpoint. Handles routing, header transparent forwarding, and automatic retries. """ - config = request.app.state.model_proxy_config + config = request.app.state.model_service_config # Step 1: Model Routing model_name = body.get("model", "") diff --git a/rock/sdk/model/server/config.py b/rock/sdk/model/server/config.py index 255ebf27d..aa6b55885 100644 --- a/rock/sdk/model/server/config.py +++ b/rock/sdk/model/server/config.py @@ -27,20 +27,29 @@ SESSION_END_MARKER = "SESSION_END" -class ModelProxyConfig(BaseModel): - proxy_rules: dict[str, str] = Field(default_factory=dict) - - # Only these codes will trigger a retry. +class ModelServiceConfig(BaseModel): + proxy_rules: dict[str, str] = Field( + default_factory=lambda: { + "gpt": "https://api.openai.com/v1/chat/completions", + } + ) + + # Only these codes will trigger a retry. # Codes not in this list (e.g., 400, 401, 403, or certain 5xx/6xx) will fail immediately. - retryable_status_codes: list[int] = Field(default_factory=lambda: [429, 499]) + retryable_status_codes: list[int] = Field( + default_factory=lambda: [429, 500] + ) request_timeout: int = 120 @classmethod - def from_env(cls, config_path: str | None = None): - if not config_path: - config_path = env_vars.ROCK_MODEL_PROXY_CONFIG + def from_file(cls, config_path: str | None = None): + """ + Factory method to create a config instance. + Args: + config_path: Path to the YAML file. If None, returns default config. + """ if not config_path: return cls() diff --git a/rock/sdk/model/server/main.py b/rock/sdk/model/server/main.py index 7c38a414e..52309afcf 100644 --- a/rock/sdk/model/server/main.py +++ b/rock/sdk/model/server/main.py @@ -7,11 +7,10 @@ from fastapi import FastAPI, status from fastapi.responses import JSONResponse -from rock import env_vars from rock.logger import init_logger from rock.sdk.model.server.api.local import init_local_api, local_router from rock.sdk.model.server.api.proxy import proxy_router -from rock.sdk.model.server.config import SERVICE_HOST, SERVICE_PORT, ModelProxyConfig +from rock.sdk.model.server.config import SERVICE_HOST, SERVICE_PORT, ModelServiceConfig # Configure logging logger = init_logger(__name__) @@ -21,11 +20,17 @@ async def lifespan(app: FastAPI): """Application lifespan context manager.""" logger.info("LLM Service started") - if model_servie_type == "proxy": - config_file_path = env_vars.ROCK_MODEL_PROXY_CONFIG if env_vars.ROCK_MODEL_PROXY_CONFIG else "model_proxy_config.yml" - model_proxy_config = ModelProxyConfig.from_env(config_file_path) - app.state.model_proxy_config = model_proxy_config - logger.info(f"Model Proxy Config loaded from {config_file_path}") + config_path = getattr(app.state, "config_path", None) + if config_path: + try: + app.state.model_service_config = ModelServiceConfig.from_file(config_path) + logger.info(f"Model Service Config loaded from: {config_path}") + except Exception as e: + logger.error(f"Failed to load config from {config_path}: {e}") + raise e + else: + app.state.model_service_config = ModelServiceConfig() + logger.info("No config file specified. Using default config settings.") yield logger.info("LLM Service shutting down") @@ -55,8 +60,9 @@ async def global_exception_handler(request, exc): ) -def main(model_servie_type: str): +def main(model_servie_type: str, config_file: str | None): logger.info(f"Starting LLM Service on {SERVICE_HOST}:{SERVICE_PORT}, type: {model_servie_type}") + app.state.config_path = config_file if model_servie_type == "local": asyncio.run(init_local_api()) app.include_router(local_router, prefix="", tags=["local"]) @@ -70,6 +76,9 @@ def main(model_servie_type: str): parser.add_argument( "--type", type=str, choices=["local", "proxy"], default="local", help="Type of LLM service (local/proxy)" ) + parser.add_argument( + "--config-file", type=str, default=None, help="Path to the configuration YAML file. If not set, default values will be used." + ) args = parser.parse_args() model_servie_type = args.type diff --git a/rock/sdk/model/server/model_proxy_config.yml b/rock/sdk/model/server/model_proxy_config.yml deleted file mode 100644 index 376516d05..000000000 --- a/rock/sdk/model/server/model_proxy_config.yml +++ /dev/null @@ -1,24 +0,0 @@ -proxy_rules: - qwen: "https://offline-whale-wave.alibaba-inc.com/api/v2/services/aigc/text-generation/v1/chat/completions" - gpt: "https://api.openai.com/v1/chat/completions" -retryable_status_codes: - - 409, # ENGINE_CONCURRENCY_CONFLICT - - 429, # THROTTLING - - 499, # CLIENT_DISCONNECT - - 518, # ROUTE_CONFIG_ERROR - - 519, # NETTY_CATCH_ERROR - - 520, # PYRTP_CONNECT_ERROR - - 523, # READ_TIME_OUT - - 524, # RETRY_EXHAUSTED - - 527, # CONNECTION_RESET_EXCEPTION - - 528, # ENGINE_RESPONSE_TOO_LARGE - - 600, # BALANCING_WORKER_EXCEPTION - - 602, # MALLOC_EXCEPTION - - 603, # ENGINE_TIMEOUT_EXCEPTION - - 8000, # ENGINE_ABNORMAL_DISCONNECT_EXCEPTION - - 8200, # GET_HOST_FAILED - - 8202, # CONNECT_FAILED - - 8211, # DECODE_MALLOC_FAILED - - 8213, # WAIT_TO_RUN_TIMEOUT - - 8303 # CACHE_STORE_LOAD_SEND_REQUEST_FAILED -timeout: 180 \ No newline at end of file diff --git a/tests/unit/sdk/model/test_proxy.py b/tests/unit/sdk/model/test_proxy.py index 2102d19a4..6706f02b0 100644 --- a/tests/unit/sdk/model/test_proxy.py +++ b/tests/unit/sdk/model/test_proxy.py @@ -2,28 +2,23 @@ import httpx import pytest +import yaml from fastapi import FastAPI, Request from httpx import ASGITransport, AsyncClient, HTTPStatusError, Request, Response from rock.sdk.model.server.api.proxy import perform_llm_request, proxy_router -from rock.sdk.model.server.config import ModelProxyConfig +from rock.sdk.model.server.config import ModelServiceConfig +from rock.sdk.model.server.main import lifespan # Initialize a temporary FastAPI application for testing the router test_app = FastAPI() test_app.include_router(proxy_router) -mock_config = ModelProxyConfig( - proxy_rules={ - "qwen": "http://whale.url", - "default": "http://default.url" - }, - retryable_status_codes=[429, 499], - request_timeout=60 -) -test_app.state.model_proxy_config = mock_config +mock_config = ModelServiceConfig() +test_app.state.model_service_config = mock_config @pytest.mark.asyncio -async def test_chat_completions_routing(): +async def test_chat_completions_routing_success(): """ Test the high-level routing logic. """ @@ -38,17 +33,34 @@ async def test_chat_completions_routing(): transport = ASGITransport(app=test_app) async with AsyncClient(transport=transport, base_url="http://test") as ac: payload = { - "model": "Qwen2.5-72B", + "model": "gpt", "messages": [{"role": "user", "content": "hello"}] } response = await ac.post("/v1/chat/completions", json=payload) assert response.status_code == 200 call_args = mock_request.call_args[0] - assert call_args[0] == "http://whale.url" + assert call_args[0] == "https://api.openai.com/v1/chat/completions" assert mock_request.called +@pytest.mark.asyncio +async def test_chat_completions_routing_not_configured_fail(): + """ + Test that the proxy fails when the model is not configured. + """ + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + payload = { + "model": "claude-3", + "messages": [{"role": "user", "content": "hello"}] + } + response = await ac.post("/v1/chat/completions", json=payload) + + assert response.status_code == 400 + assert "not configured" in response.json()["detail"] + + @pytest.mark.asyncio async def test_perform_llm_request_retry_on_whitelist(): """ @@ -64,8 +76,8 @@ async def test_perform_llm_request_retry_on_whitelist(): resp_429 = MagicMock(spec=Response) resp_429.status_code = 429 error_429 = HTTPStatusError( - "Rate Limited", - request=MagicMock(spec=Request), + "Rate Limited", + request=MagicMock(spec=Request), response=resp_429 ) @@ -104,7 +116,7 @@ async def test_perform_llm_request_no_retry_on_non_whitelist(): assert result.status_code == 401 # Call count must be 1, meaning no retries were attempted - assert mock_post.call_count == 1 + assert mock_post.call_count == 1 @pytest.mark.asyncio @@ -126,3 +138,58 @@ async def test_perform_llm_request_network_timeout_retry(): assert result.status_code == 200 assert mock_post.call_count == 2 + + +@pytest.mark.asyncio +async def test_lifespan_initialization_with_config(tmp_path): + """ + Test that the application correctly initializes and overrides defaults + when a valid configuration file path is provided. + """ + conf_file = tmp_path / "proxy.yml" + conf_file.write_text(yaml.dump({ + "proxy_rules": {"my-model": "http://custom-url"}, + "request_timeout": 50 + })) + + # Initialize App and simulate CLI argument passing via app.state + app = FastAPI(lifespan=lifespan) + app.state.config_path = str(conf_file) + + async with lifespan(app): + config = app.state.model_service_config + # Verify that the config reflects file content instead of defaults + assert config.proxy_rules["my-model"] == "http://custom-url" + assert config.request_timeout == 50 + assert "gpt" not in config.proxy_rules + + +@pytest.mark.asyncio +async def test_lifespan_initialization_no_config(): + """ + Test that the application initializes with default ModelServiceConfig + settings when no configuration file path is provided. + """ + app = FastAPI(lifespan=lifespan) + app.state.config_path = None + + async with lifespan(app): + config = app.state.model_service_config + # Verify that default rules (e.g., 'gpt') are loaded + assert "gpt" in config.proxy_rules + assert config.request_timeout == 120 + + +@pytest.mark.asyncio +async def test_lifespan_invalid_config_path(): + """ + Test that providing a non-existent configuration file path causes the + lifespan to raise a FileNotFoundError, ensuring fail-fast behavior. + """ + app = FastAPI(lifespan=lifespan) + app.state.config_path = "/tmp/non_existent_file.yml" + + # Expect FileNotFoundError to be raised during startup + with pytest.raises(FileNotFoundError): + async with lifespan(app): + pass From 5ecb1dd6236c56d2832e29611ed081713c74373b Mon Sep 17 00:00:00 2001 From: fengling Date: Thu, 22 Jan 2026 11:59:38 +0000 Subject: [PATCH 3/5] fix: update proxy url and add default fallback --- rock/sdk/model/server/api/proxy.py | 18 ++++---- rock/sdk/model/server/config.py | 3 +- tests/unit/sdk/model/test_proxy.py | 68 ++++++++++++++++++++++++------ 3 files changed, 66 insertions(+), 23 deletions(-) diff --git a/rock/sdk/model/server/api/proxy.py b/rock/sdk/model/server/api/proxy.py index f9c2ab778..0917987b6 100644 --- a/rock/sdk/model/server/api/proxy.py +++ b/rock/sdk/model/server/api/proxy.py @@ -40,18 +40,19 @@ async def perform_llm_request(url: str, body: dict, headers: dict, config: Model return response -def get_target_url(model_name: str, config: ModelServiceConfig) -> str: +def get_base_url(model_name: str, config: ModelServiceConfig) -> str: """ Selects the target backend URL based on model name matching. """ - rules = config.proxy_rules if not model_name: raise HTTPException(status_code=400, detail="Model name is required for routing.") - if model_name in rules: - return rules[model_name] + rules = config.proxy_rules + base_url = rules.get(model_name) or rules.get("default") + if not base_url: + raise HTTPException(status_code=400, detail=f"Model '{model_name}' is not configured and no 'default' rule found.") - raise HTTPException(status_code=400, detail=f"Model '{model_name}' not configured in proxy rules.") + return base_url.rstrip("/") @proxy_router.post("/v1/chat/completions") @@ -64,7 +65,8 @@ async def chat_completions(body: dict[str, Any], request: Request): # Step 1: Model Routing model_name = body.get("model", "") - target_url = get_target_url(model_name, config) + base_url = get_base_url(model_name, config) + target_url = f"{base_url}/v1/chat/completions" logger.info(f"Routing model '{model_name}' to URL: {target_url}") # Step 2: Header Cleaning @@ -81,12 +83,12 @@ async def chat_completions(body: dict[str, Any], request: Request): try: # Step 4: Execute Request with Retry Logic - response = await perform_llm_request(target_url, body, forwarded_headers) + response = await perform_llm_request(target_url, body, forwarded_headers, config) return JSONResponse(status_code=response.status_code, content=response.json()) except httpx.HTTPStatusError as e: # Forward the raw backend error message to the client. - # This allows the Agent-side logic to detect keywords like 'context length exceeded' + # This allows the Agent-side logic to detect keywords like 'context length exceeded' # or 'content violation' and raise appropriate exceptions. error_text = e.response.text if e.response else "No error details" status_code = e.response.status_code if e.response else 502 diff --git a/rock/sdk/model/server/config.py b/rock/sdk/model/server/config.py index aa6b55885..583bf3c19 100644 --- a/rock/sdk/model/server/config.py +++ b/rock/sdk/model/server/config.py @@ -30,7 +30,8 @@ class ModelServiceConfig(BaseModel): proxy_rules: dict[str, str] = Field( default_factory=lambda: { - "gpt": "https://api.openai.com/v1/chat/completions", + "gpt-3.5-turbo": "https://api.openai.com", + "default": "https://api-inference.modelscope.cn" } ) diff --git a/tests/unit/sdk/model/test_proxy.py b/tests/unit/sdk/model/test_proxy.py index 6706f02b0..622f9cf1e 100644 --- a/tests/unit/sdk/model/test_proxy.py +++ b/tests/unit/sdk/model/test_proxy.py @@ -33,7 +33,7 @@ async def test_chat_completions_routing_success(): transport = ASGITransport(app=test_app) async with AsyncClient(transport=transport, base_url="http://test") as ac: payload = { - "model": "gpt", + "model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hello"}] } response = await ac.post("/v1/chat/completions", json=payload) @@ -45,20 +45,60 @@ async def test_chat_completions_routing_success(): @pytest.mark.asyncio -async def test_chat_completions_routing_not_configured_fail(): +async def test_chat_completions_fallback_to_default_when_not_found(): """ - Test that the proxy fails when the model is not configured. + Test that an unrecognized model name correctly falls back to the 'default' URL. """ - transport = ASGITransport(app=test_app) - async with AsyncClient(transport=transport, base_url="http://test") as ac: - payload = { - "model": "claude-3", - "messages": [{"role": "user", "content": "hello"}] - } - response = await ac.post("/v1/chat/completions", json=payload) + patch_path = 'rock.sdk.model.server.api.proxy.perform_llm_request' + + with patch(patch_path, new_callable=AsyncMock) as mock_request: + mock_resp = MagicMock(spec=Response) + mock_resp.status_code = 200 + mock_resp.json.return_value = {"id": "chat-fallback", "choices": []} + mock_request.return_value = mock_resp + + config = test_app.state.model_service_config + default_base_url = config.proxy_rules["default"].rstrip("/") + expected_target_url = f"{default_base_url}/v1/chat/completions" + + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + payload = { + "model": "some-random-unsupported-model", # This model is NOT in proxy_rules + "messages": [{"role": "user", "content": "hello"}] + } + response = await ac.post("/v1/chat/completions", json=payload) + + assert response.status_code == 200 + + # Verify that perform_llm_request was called with the DEFAULT URL + call_args = mock_request.call_args[0] + actual_url = call_args[0] + + assert actual_url == expected_target_url + assert mock_request.called + + +@pytest.mark.asyncio +async def test_chat_completions_routing_absolute_fail(): + """ + Test that both the specific model and the 'default' rule are missing. + """ + empty_config = ModelServiceConfig() + empty_config.proxy_rules = {} + + with patch.object(test_app.state, 'model_service_config', empty_config): + transport = ASGITransport(app=test_app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + payload = { + "model": "any-model", + "messages": [{"role": "user", "content": "hello"}] + } + response = await ac.post("/v1/chat/completions", json=payload) assert response.status_code == 400 - assert "not configured" in response.json()["detail"] + detail = response.json()["detail"] + assert "not configured" in detail @pytest.mark.asyncio @@ -161,7 +201,7 @@ async def test_lifespan_initialization_with_config(tmp_path): # Verify that the config reflects file content instead of defaults assert config.proxy_rules["my-model"] == "http://custom-url" assert config.request_timeout == 50 - assert "gpt" not in config.proxy_rules + assert "gpt-3.5-turbo" not in config.proxy_rules @pytest.mark.asyncio @@ -175,8 +215,8 @@ async def test_lifespan_initialization_no_config(): async with lifespan(app): config = app.state.model_service_config - # Verify that default rules (e.g., 'gpt') are loaded - assert "gpt" in config.proxy_rules + # Verify that default rules (e.g., 'gpt-3.5-turbo') are loaded + assert "gpt-3.5-turbo" in config.proxy_rules assert config.request_timeout == 120 From f49cd7af2de41ef92acb8bffc0a14fa930bd2d0e Mon Sep 17 00:00:00 2001 From: fengling Date: Thu, 22 Jan 2026 12:37:34 +0000 Subject: [PATCH 4/5] fix: add config file parameter to main function --- rock/sdk/model/server/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rock/sdk/model/server/main.py b/rock/sdk/model/server/main.py index 52309afcf..f1b871921 100644 --- a/rock/sdk/model/server/main.py +++ b/rock/sdk/model/server/main.py @@ -81,5 +81,6 @@ def main(model_servie_type: str, config_file: str | None): ) args = parser.parse_args() model_servie_type = args.type + config_file = args.config_file - main(model_servie_type) + main(model_servie_type, config_file) From f57feea3af60d66910dc8898038da05b778f6519 Mon Sep 17 00:00:00 2001 From: fengling Date: Fri, 23 Jan 2026 07:15:05 +0000 Subject: [PATCH 5/5] fix: update base_url --- rock/sdk/model/server/api/proxy.py | 2 +- rock/sdk/model/server/config.py | 4 ++-- tests/unit/sdk/model/test_proxy.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rock/sdk/model/server/api/proxy.py b/rock/sdk/model/server/api/proxy.py index 0917987b6..2d77f2a42 100644 --- a/rock/sdk/model/server/api/proxy.py +++ b/rock/sdk/model/server/api/proxy.py @@ -66,7 +66,7 @@ async def chat_completions(body: dict[str, Any], request: Request): # Step 1: Model Routing model_name = body.get("model", "") base_url = get_base_url(model_name, config) - target_url = f"{base_url}/v1/chat/completions" + target_url = f"{base_url}/chat/completions" logger.info(f"Routing model '{model_name}' to URL: {target_url}") # Step 2: Header Cleaning diff --git a/rock/sdk/model/server/config.py b/rock/sdk/model/server/config.py index 583bf3c19..a7e4603c6 100644 --- a/rock/sdk/model/server/config.py +++ b/rock/sdk/model/server/config.py @@ -30,8 +30,8 @@ class ModelServiceConfig(BaseModel): proxy_rules: dict[str, str] = Field( default_factory=lambda: { - "gpt-3.5-turbo": "https://api.openai.com", - "default": "https://api-inference.modelscope.cn" + "gpt-3.5-turbo": "https://api.openai.com/v1", + "default": "https://api-inference.modelscope.cn/v1" } ) diff --git a/tests/unit/sdk/model/test_proxy.py b/tests/unit/sdk/model/test_proxy.py index 622f9cf1e..26d674216 100644 --- a/tests/unit/sdk/model/test_proxy.py +++ b/tests/unit/sdk/model/test_proxy.py @@ -59,7 +59,7 @@ async def test_chat_completions_fallback_to_default_when_not_found(): config = test_app.state.model_service_config default_base_url = config.proxy_rules["default"].rstrip("/") - expected_target_url = f"{default_base_url}/v1/chat/completions" + expected_target_url = f"{default_base_url}/chat/completions" transport = ASGITransport(app=test_app) async with AsyncClient(transport=transport, base_url="http://test") as ac: