# Reconstructed from mangled diff: rock/sdk/model/server/api/proxy.py
# (module header, shared HTTP client, retrying request helper, routing helper).
from typing import Any

import httpx
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse

from rock.logger import init_logger
from rock.sdk.model.server.config import ModelServiceConfig
from rock.utils import retry_async

logger = init_logger(__name__)

proxy_router = APIRouter()

# Global HTTP client with a persistent connection pool.
# NOTE(review): this client is never closed; consider calling
# `await http_client.aclose()` from the application lifespan shutdown hook
# so pooled connections are released cleanly.
http_client = httpx.AsyncClient()


@retry_async(
    max_attempts=6,
    delay_seconds=2.0,
    backoff=2.0,  # Exponential backoff: 2s, 4s, 8s, 16s, 32s between attempts.
    jitter=True,  # Randomize delays to avoid a "thundering herd" on the backend.
    exceptions=(httpx.TimeoutException, httpx.ConnectError, httpx.HTTPStatusError),
)
async def perform_llm_request(url: str, body: dict, headers: dict, config: ModelServiceConfig):
    """Forward the request to the backend, retrying only whitelisted statuses.

    A retry is triggered ONLY when the response status code is listed in
    ``config.retryable_status_codes``; every other response (including
    non-retryable errors such as 400/401/403) is returned to the caller
    unchanged.

    Args:
        url: Fully-resolved backend endpoint URL.
        body: JSON-serializable request payload.
        headers: Headers to forward to the backend.
        config: Service configuration providing the timeout and retry whitelist.

    Returns:
        The raw ``httpx.Response`` from the backend.

    Raises:
        httpx.HTTPStatusError: For whitelisted statuses (consumed by the retry
            decorator until attempts are exhausted, then propagated).
        httpx.TimeoutException, httpx.ConnectError: On network failures
            (also retried by the decorator).
    """
    response = await http_client.post(url, json=body, headers=headers, timeout=config.request_timeout)
    status_code = response.status_code

    # Check against the explicit whitelist.
    if status_code in config.retryable_status_codes:
        logger.warning(f"Retryable error detected: {status_code}. Triggering retry for {url}...")
        # Raise explicitly instead of relying on raise_for_status(), which
        # only raises for 4xx/5xx: this honors the whitelist contract even if
        # a non-error status code is ever configured as retryable.
        raise httpx.HTTPStatusError(
            f"Retryable status {status_code} from {url}",
            request=response.request,
            response=response,
        )

    return response


def get_base_url(model_name: str, config: ModelServiceConfig) -> str:
    """Select the target backend base URL for *model_name*.

    Falls back to the ``"default"`` rule when the model has no explicit entry
    in ``config.proxy_rules``.

    Raises:
        HTTPException: 400 when the model name is empty, or when neither a
            model-specific rule nor a ``"default"`` rule is configured.
    """
    if not model_name:
        raise HTTPException(status_code=400, detail="Model name is required for routing.")

    rules = config.proxy_rules
    base_url = rules.get(model_name) or rules.get("default")
    if not base_url:
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' is not configured and no 'default' rule found.")

    return base_url.rstrip("/")
# Reconstructed from mangled diff: rock/sdk/model/server/api/proxy.py
# (the /v1/chat/completions proxy endpoint).
@proxy_router.post("/v1/chat/completions")
async def chat_completions(body: dict[str, Any], request: Request):
    """OpenAI-compatible chat completions proxy endpoint.

    Handles model-based routing, transparent header forwarding, and
    automatic retries for whitelisted backend errors.
    """
    config = request.app.state.model_service_config

    # Step 1: Model routing.
    model_name = body.get("model", "")
    base_url = get_base_url(model_name, config)
    target_url = f"{base_url}/chat/completions"
    logger.info(f"Routing model '{model_name}' to URL: {target_url}")

    # Step 2: Header cleaning.
    # Preserve 'Authorization' for authentication while dropping hop-by-hop
    # headers (RFC 7230 section 6.1 -- a proxy must not forward these) plus
    # entity headers that httpx recomputes for the re-serialized JSON body.
    excluded_headers = {
        "host", "content-length", "content-type", "transfer-encoding",
        "connection", "keep-alive", "te", "trailer", "upgrade",
        "proxy-authenticate", "proxy-authorization",
    }
    forwarded_headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in excluded_headers
    }

    # Step 3: Strategy enforcement.
    # Force non-streaming mode for the MVP phase to ensure stability.
    body["stream"] = False

    try:
        # Step 4: Execute the request with retry logic.
        response = await perform_llm_request(target_url, body, forwarded_headers, config)
        return JSONResponse(status_code=response.status_code, content=response.json())

    except httpx.HTTPStatusError as e:
        # Forward the raw backend error message to the client.
        # This allows the Agent-side logic to detect keywords like
        # 'context length exceeded' or 'content violation' and raise
        # appropriate exceptions.
        error_text = e.response.text if e.response else "No error details"
        status_code = e.response.status_code if e.response else 502
        logger.error(f"Final failure after retries. Status: {status_code}, Response: {error_text}")
        return JSONResponse(
            status_code=status_code,
            content={
                "error": {
                    "message": f"LLM backend error: {error_text}",
                    "type": "proxy_retry_failed",
                    "code": status_code,
                }
            },
        )
    except Exception as e:
        logger.error(f"Unexpected proxy error: {str(e)}")
        # Raise a standard 500 for non-HTTP-related coding or system errors.
        raise HTTPException(status_code=500, detail=str(e))
# Reconstructed from mangled diff: rock/sdk/model/server/config.py
# (new imports and the ModelServiceConfig pydantic model appended to the
# existing constants module).
from pathlib import Path

import yaml
from pydantic import BaseModel, Field


class ModelServiceConfig(BaseModel):
    """Runtime configuration for the LLM proxy service."""

    # Maps model name -> backend base URL; "default" is the fallback rule.
    proxy_rules: dict[str, str] = Field(
        default_factory=lambda: {
            "gpt-3.5-turbo": "https://api.openai.com/v1",
            "default": "https://api-inference.modelscope.cn/v1",
        }
    )

    # Only these codes will trigger a retry.
    # Codes not in this list (e.g., 400, 401, 403, or other 5xx codes)
    # will fail immediately.
    retryable_status_codes: list[int] = Field(
        default_factory=lambda: [429, 500]
    )

    # Per-request timeout (seconds) for backend calls.
    request_timeout: int = 120

    @classmethod
    def from_file(cls, config_path: str | None = None):
        """Factory method to create a config instance.

        Args:
            config_path: Path to the YAML file. If None, returns the default
                config.

        Raises:
            FileNotFoundError: If *config_path* is set but does not exist.
            ValueError: If the YAML root is not a mapping (a list or scalar
                cannot be expanded into model fields).
        """
        if not config_path:
            return cls()

        config_file = Path(config_path)

        if not config_file.exists():
            raise FileNotFoundError(f"Config file {config_file} not found")

        with open(config_file, encoding="utf-8") as f:
            config_data = yaml.safe_load(f)

        if config_data is None:
            # An empty file is treated as "use all defaults".
            return cls()
        if not isinstance(config_data, dict):
            # Fail with a clear message instead of a confusing TypeError
            # from cls(**config_data).
            raise ValueError(f"Config file {config_file} must contain a YAML mapping")

        return cls(**config_data)
# Reconstructed from mangled diff: rock/sdk/model/server/main.py
# (config-aware lifespan and CLI wiring). Only the lines visible in the
# source hunks are reconstructed; elided regions are marked explicitly.
# NOTE(review): the decorator on lifespan (presumably @asynccontextmanager,
# given tests use `async with lifespan(app)`) sits above the visible hunk --
# it is unchanged by the patch and must be kept.
async def lifespan(app: FastAPI):
    """Application lifespan context manager.

    Loads ModelServiceConfig from ``app.state.config_path`` when set,
    otherwise falls back to built-in defaults. Fails fast (re-raises) when a
    configured path cannot be loaded.
    """
    logger.info("LLM Service started")
    config_path = getattr(app.state, "config_path", None)
    if config_path:
        try:
            app.state.model_service_config = ModelServiceConfig.from_file(config_path)
            logger.info(f"Model Service Config loaded from: {config_path}")
        except Exception as e:
            logger.error(f"Failed to load config from {config_path}: {e}")
            # Bare `raise` (not `raise e`) preserves the original traceback.
            raise
    else:
        app.state.model_service_config = ModelServiceConfig()
        logger.info("No config file specified. Using default config settings.")
    yield
    logger.info("LLM Service shutting down")


def main(model_servie_type: str, config_file: str | None):
    """Start the LLM service.

    Args:
        model_servie_type: Service mode, "local" or "proxy".
            NOTE(review): 'servie' is a pre-existing typo ('service');
            kept to preserve the existing call interface.
        config_file: Optional YAML config path, stored on app.state so the
            lifespan hook can load it at startup.
    """
    logger.info(f"Starting LLM Service on {SERVICE_HOST}:{SERVICE_PORT}, type: {model_servie_type}")
    app.state.config_path = config_file
    if model_servie_type == "local":
        asyncio.run(init_local_api())
        app.include_router(local_router, prefix="", tags=["local"])
    # ... remainder of main() unchanged (outside the visible hunk) ...


# CLI wiring (visible hunk only; argparse setup above it is unchanged).
# The patch also updates the import line to:
#     from rock.sdk.model.server.config import SERVICE_HOST, SERVICE_PORT, ModelServiceConfig
# and adds, inside the existing `if __name__ == "__main__":` block:
#     parser.add_argument(
#         "--config-file", type=str, default=None,
#         help="Path to the configuration YAML file. If not set, default values will be used.",
#     )
#     args = parser.parse_args()
#     model_servie_type = args.type
#     config_file = args.config_file
#     main(model_servie_type, config_file)
# Reconstructed from mangled diff: tests/unit/sdk/model/test_proxy.py
# (new test module; its import block and first test begin at the tail of the
# previous source chunk and are included here in full).
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest
import yaml
# Fix: the original imported Request from fastapi AND httpx; the httpx import
# silently shadowed the fastapi one, which was never used as a fastapi type.
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient, HTTPStatusError, Request, Response

from rock.sdk.model.server.api.proxy import perform_llm_request, proxy_router
from rock.sdk.model.server.config import ModelServiceConfig
from rock.sdk.model.server.main import lifespan

# Initialize a temporary FastAPI application for testing the router.
test_app = FastAPI()
test_app.include_router(proxy_router)

mock_config = ModelServiceConfig()
test_app.state.model_service_config = mock_config


@pytest.mark.asyncio
async def test_chat_completions_routing_success():
    """Test the high-level routing logic."""
    patch_path = 'rock.sdk.model.server.api.proxy.perform_llm_request'

    with patch(patch_path, new_callable=AsyncMock) as mock_request:
        mock_resp = MagicMock(spec=Response)
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"id": "chat-123", "choices": []}
        mock_request.return_value = mock_resp

        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            payload = {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": "hello"}],
            }
            response = await ac.post("/v1/chat/completions", json=payload)

        assert response.status_code == 200
        call_args = mock_request.call_args[0]
        assert call_args[0] == "https://api.openai.com/v1/chat/completions"
        assert mock_request.called


@pytest.mark.asyncio
async def test_chat_completions_fallback_to_default_when_not_found():
    """Test that an unrecognized model name correctly falls back to the 'default' URL."""
    patch_path = 'rock.sdk.model.server.api.proxy.perform_llm_request'

    with patch(patch_path, new_callable=AsyncMock) as mock_request:
        mock_resp = MagicMock(spec=Response)
        mock_resp.status_code = 200
        mock_resp.json.return_value = {"id": "chat-fallback", "choices": []}
        mock_request.return_value = mock_resp

        config = test_app.state.model_service_config
        default_base_url = config.proxy_rules["default"].rstrip("/")
        expected_target_url = f"{default_base_url}/chat/completions"

        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            payload = {
                "model": "some-random-unsupported-model",  # This model is NOT in proxy_rules
                "messages": [{"role": "user", "content": "hello"}],
            }
            response = await ac.post("/v1/chat/completions", json=payload)

        assert response.status_code == 200

        # Verify that perform_llm_request was called with the DEFAULT URL.
        call_args = mock_request.call_args[0]
        actual_url = call_args[0]

        assert actual_url == expected_target_url
        assert mock_request.called


@pytest.mark.asyncio
async def test_chat_completions_routing_absolute_fail():
    """Test that both the specific model and the 'default' rule are missing."""
    empty_config = ModelServiceConfig()
    empty_config.proxy_rules = {}

    with patch.object(test_app.state, 'model_service_config', empty_config):
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            payload = {
                "model": "any-model",
                "messages": [{"role": "user", "content": "hello"}],
            }
            response = await ac.post("/v1/chat/completions", json=payload)

        assert response.status_code == 400
        detail = response.json()["detail"]
        assert "not configured" in detail


@pytest.mark.asyncio
async def test_perform_llm_request_retry_on_whitelist():
    """Test that the proxy retries when receiving a whitelisted error code."""
    client_post_path = 'rock.sdk.model.server.api.proxy.http_client.post'

    # Patch asyncio.sleep inside the retry module to avoid actual waiting.
    with patch(client_post_path, new_callable=AsyncMock) as mock_post, \
            patch('rock.utils.retry.asyncio.sleep', return_value=None):

        # 1. Setup Failed Response (429).
        resp_429 = MagicMock(spec=Response)
        resp_429.status_code = 429
        error_429 = HTTPStatusError(
            "Rate Limited",
            request=MagicMock(spec=Request),
            response=resp_429,
        )

        # 2. Setup Success Response (200).
        resp_200 = MagicMock(spec=Response)
        resp_200.status_code = 200
        resp_200.json.return_value = {"ok": True}

        # Sequence: Fail with 429, then Succeed with 200.
        mock_post.side_effect = [error_429, resp_200]

        result = await perform_llm_request("http://fake.url", {}, {}, mock_config)

        assert result.status_code == 200
        assert mock_post.call_count == 2


@pytest.mark.asyncio
async def test_perform_llm_request_no_retry_on_non_whitelist():
    """Test that the proxy DOES NOT retry for non-retryable codes (e.g., 401).

    It should return the error response immediately.
    """
    client_post_path = 'rock.sdk.model.server.api.proxy.http_client.post'

    with patch(client_post_path, new_callable=AsyncMock) as mock_post:
        # Mock 401 Unauthorized (NOT in the retry whitelist).
        resp_401 = MagicMock(spec=Response)
        resp_401.status_code = 401
        resp_401.json.return_value = {"error": "Invalid API Key"}

        # The function should return this response directly.
        mock_post.return_value = resp_401

        result = await perform_llm_request("http://fake.url", {}, {}, mock_config)

        assert result.status_code == 401
        # Call count must be 1, meaning no retries were attempted.
        assert mock_post.call_count == 1


@pytest.mark.asyncio
async def test_perform_llm_request_network_timeout_retry():
    """Test that network-level exceptions (like Timeout) also trigger retries."""
    client_post_path = 'rock.sdk.model.server.api.proxy.http_client.post'

    with patch(client_post_path, new_callable=AsyncMock) as mock_post, \
            patch('rock.utils.retry.asyncio.sleep', return_value=None):

        resp_200 = MagicMock(spec=Response)
        resp_200.status_code = 200

        mock_post.side_effect = [httpx.TimeoutException("Network Timeout"), resp_200]

        result = await perform_llm_request("http://fake.url", {}, {}, mock_config)

        assert result.status_code == 200
        assert mock_post.call_count == 2


@pytest.mark.asyncio
async def test_lifespan_initialization_with_config(tmp_path):
    """Test that the application correctly initializes and overrides defaults
    when a valid configuration file path is provided.
    """
    conf_file = tmp_path / "proxy.yml"
    conf_file.write_text(yaml.dump({
        "proxy_rules": {"my-model": "http://custom-url"},
        "request_timeout": 50,
    }))

    # Initialize App and simulate CLI argument passing via app.state.
    app = FastAPI(lifespan=lifespan)
    app.state.config_path = str(conf_file)

    async with lifespan(app):
        config = app.state.model_service_config
        # Verify that the config reflects file content instead of defaults.
        assert config.proxy_rules["my-model"] == "http://custom-url"
        assert config.request_timeout == 50
        assert "gpt-3.5-turbo" not in config.proxy_rules


@pytest.mark.asyncio
async def test_lifespan_initialization_no_config():
    """Test that the application initializes with default ModelServiceConfig
    settings when no configuration file path is provided.
    """
    app = FastAPI(lifespan=lifespan)
    app.state.config_path = None

    async with lifespan(app):
        config = app.state.model_service_config
        # Verify that default rules (e.g., 'gpt-3.5-turbo') are loaded.
        assert "gpt-3.5-turbo" in config.proxy_rules
        assert config.request_timeout == 120


@pytest.mark.asyncio
async def test_lifespan_invalid_config_path():
    """Test that providing a non-existent configuration file path causes the
    lifespan to raise a FileNotFoundError, ensuring fail-fast behavior.
    """
    app = FastAPI(lifespan=lifespan)
    app.state.config_path = "/tmp/non_existent_file.yml"

    # Expect FileNotFoundError to be raised during startup.
    with pytest.raises(FileNotFoundError):
        async with lifespan(app):
            pass