From 8c62dbabde8aea97ae280af7b42ca7bc5abcedd4 Mon Sep 17 00:00:00 2001 From: statxc <181730535+statxc@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:30:27 +0000 Subject: [PATCH 1/3] feat: Detect platform-side inference errors so agents aren't penalized for provider failures --- inference_gateway/config.py | 6 + inference_gateway/error_hash_map.py | 50 +++++ inference_gateway/main.py | 47 +++- models/evaluation_run.py | 1 + tests/__init__.py | 0 tests/test_inference_error_tracking.py | 291 +++++++++++++++++++++++++ validator/main.py | 25 ++- 7 files changed, 416 insertions(+), 4 deletions(-) create mode 100644 inference_gateway/error_hash_map.py create mode 100644 tests/__init__.py create mode 100644 tests/test_inference_error_tracking.py diff --git a/inference_gateway/config.py b/inference_gateway/config.py index 77ac3b658..c93d9bd92 100644 --- a/inference_gateway/config.py +++ b/inference_gateway/config.py @@ -59,6 +59,12 @@ logger.fatal("MAX_COST_PER_EVALUATION_RUN_USD is not set in .env") MAX_COST_PER_EVALUATION_RUN_USD = float(MAX_COST_PER_EVALUATION_RUN_USD) +MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = os.getenv("MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN") +if not MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN: + MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + logger.warning(f"MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN is not set in .env, defaulting to {MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN}") +MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = int(MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN) + USE_CHUTES = os.getenv("USE_CHUTES") diff --git a/inference_gateway/error_hash_map.py b/inference_gateway/error_hash_map.py new file mode 100644 index 000000000..d56a3fcf5 --- /dev/null +++ b/inference_gateway/error_hash_map.py @@ -0,0 +1,50 @@ +# Tracks the number of non-halting (platform-side) inference errors per +# evaluation run. When the count exceeds a configured threshold the run +# is flagged as a platform error so the agent is not penalized unfairly. + +import time + +from uuid import UUID +from pydantic import BaseModel + + + +ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS = 60 # 1 minute + +class ErrorHashMapEntry(BaseModel): + inference_errors: int + last_accessed_at: float + +class ErrorHashMap: + def __init__(self): + self.error_hash_map = {} + self.last_cleanup_at = time.time() + + + + def _cleanup(self): + now = time.time() + if now - self.last_cleanup_at > ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS: + self.error_hash_map = {k: v for k, v in self.error_hash_map.items() if now - v.last_accessed_at < ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS} + self.last_cleanup_at = now + + + + def get_inference_errors(self, uuid: UUID) -> int: + self._cleanup() + + if uuid in self.error_hash_map: + self.error_hash_map[uuid].last_accessed_at = time.time() + return self.error_hash_map[uuid].inference_errors + else: + return 0 + + def add_inference_error(self, uuid: UUID): + self._cleanup() + + if uuid in self.error_hash_map: + entry = self.error_hash_map[uuid] + entry.inference_errors += 1 + entry.last_accessed_at = time.time() + else: + self.error_hash_map[uuid] = ErrorHashMapEntry(inference_errors=1, last_accessed_at=time.time()) diff --git a/inference_gateway/main.py b/inference_gateway/main.py index fcb97aa0e..cf737387c 100644 --- a/inference_gateway/main.py +++ b/inference_gateway/main.py @@ -12,6 +12,7 @@ from contextlib import asynccontextmanager from models.evaluation_run import EvaluationRunStatus from inference_gateway.cost_hash_map import CostHashMap +from inference_gateway.error_hash_map import ErrorHashMap from inference_gateway.providers.provider import Provider from inference_gateway.providers.chutes import ChutesProvider from inference_gateway.providers.targon import TargonProvider @@ -107,6 +108,18 @@ async def wrapper(*args, **kwargs): cost_hash_map = CostHashMap() +error_hash_map = ErrorHashMap() + + + +# Platform-side error codes that are not the agent's fault. 4xx errors like +# 400/404/422 are excluded because those indicate bad agent requests (wrong +# model, invalid format, etc). 429 is excluded because it means the cost +# limit was intentionally reached. +NON_HALTING_ERROR_CODES = {500, 502, 503, 504, -1} + +def is_non_halting_error(status_code: int) -> bool: + return status_code in NON_HALTING_ERROR_CODES @@ -158,7 +171,15 @@ async def inference(request: InferenceRequest) -> InferenceResponse: detail=f"The evaluation run with ID {request.evaluation_run_id} has reached or exceeded the evaluation run cost limit of {config.MAX_COST_PER_EVALUATION_RUN_USD} USD (current cost: {cost} USD)." ) - + # Make sure the evaluation run has not had too many platform-side inference errors + inference_errors = error_hash_map.get_inference_errors(request.evaluation_run_id) + if inference_errors >= config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN: + raise HTTPException( + status_code=503, + detail=f"The evaluation run with ID {request.evaluation_run_id} has had too many platform-side inference errors ({inference_errors} errors, limit is {config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})." + ) + + # Make sure we support the model for inference provider = get_provider_that_supports_model_for_inference(request.model) @@ -206,6 +227,10 @@ async def inference(request: InferenceRequest) -> InferenceResponse: tool_calls=response.tool_calls ) else: + # Track non-halting errors (platform/provider failures, not agent mistakes) + if is_non_halting_error(response.status_code): + error_hash_map.add_inference_error(request.evaluation_run_id) + raise HTTPException( status_code=response.status_code, detail=response.error_message @@ -245,7 +270,15 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse: detail=f"The evaluation run with ID {request.evaluation_run_id} has reached or exceeded the evaluation run cost limit of {config.MAX_COST_PER_EVALUATION_RUN_USD} USD (current cost: {cost} USD)." ) - + # Make sure the evaluation run has not had too many platform-side inference errors + inference_errors = error_hash_map.get_inference_errors(request.evaluation_run_id) + if inference_errors >= config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN: + raise HTTPException( + status_code=503, + detail=f"The evaluation run with ID {request.evaluation_run_id} has had too many platform-side inference errors ({inference_errors} errors, limit is {config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})." + ) + + # Make sure we support the model for embedding provider = get_provider_that_supports_model_for_embedding(request.model) @@ -287,6 +320,10 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse: embedding=response.embedding ) else: + # Track non-halting errors (platform/provider failures, not agent mistakes) + if is_non_halting_error(response.status_code): + error_hash_map.add_inference_error(request.evaluation_run_id) + raise HTTPException( status_code=response.status_code, detail=response.error_message @@ -298,6 +335,8 @@ class UsageResponse(BaseModel): used_cost_usd: float remaining_cost_usd: float max_cost_usd: float + inference_errors: int + max_inference_errors: int @app.get("/api/usage") @handle_http_exceptions @@ -306,7 +345,9 @@ async def usage(evaluation_run_id: UUID) -> UsageResponse: return UsageResponse( used_cost_usd=used_cost_usd, remaining_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD - used_cost_usd, - max_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD + max_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD, + inference_errors=error_hash_map.get_inference_errors(evaluation_run_id), + max_inference_errors=config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN ) diff --git a/models/evaluation_run.py b/models/evaluation_run.py index 891c28f7c..5c70a641c 100644 --- a/models/evaluation_run.py +++ b/models/evaluation_run.py @@ -37,6 +37,7 @@ def __new__(cls, code: int, message: str): PLATFORM_RESTARTED_WHILE_RUNNING_AGENT = (3020, "The platform was restarted while the evaluation run was running the agent") PLATFORM_RESTARTED_WHILE_INIT_EVAL = (3030, "The platform was restarted while the evaluation run was initializing the evaluation") PLATFORM_RESTARTED_WHILE_RUNNING_EVAL = (3040, "The platform was restarted while the evaluation run was running the evaluation") + PLATFORM_TOO_MANY_INFERENCE_ERRORS = (3050, "Too many platform-side inference errors occurred during the evaluation run") def get_error_message(self) -> str: return self.message diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_inference_error_tracking.py b/tests/test_inference_error_tracking.py new file mode 100644 index 000000000..e1affae85 --- /dev/null +++ b/tests/test_inference_error_tracking.py @@ -0,0 +1,291 @@ +""" +Tests for platform-side inference error tracking. +Covers: + - ErrorHashMap error counting and cleanup + - Non-halting error code classification + - Inference gateway threshold enforcement via the /api/inference endpoint + - Usage endpoint reporting error counts +""" + +import os +import time +import pytest +from uuid import uuid4 + +from inference_gateway.error_hash_map import ErrorHashMap, ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS + + + +# --------------------------------------------------------------------------- +# Unit tests: ErrorHashMap +# --------------------------------------------------------------------------- + +class TestErrorHashMap: + def setup_method(self): + self.ehm = ErrorHashMap() + self.run_id = uuid4() + + def test_get_inference_errors_returns_zero_for_unknown_run(self): + assert self.ehm.get_inference_errors(uuid4()) == 0 + + def test_add_inference_error_creates_entry(self): + self.ehm.add_inference_error(self.run_id) + assert self.ehm.get_inference_errors(self.run_id) == 1 + + def test_add_inference_error_increments(self): + for _ in range(7): + self.ehm.add_inference_error(self.run_id) + assert self.ehm.get_inference_errors(self.run_id) == 7 + + def test_separate_runs_tracked_independently(self): + run_a = uuid4() + run_b = uuid4() + + self.ehm.add_inference_error(run_a) + self.ehm.add_inference_error(run_a) + self.ehm.add_inference_error(run_b) + + assert self.ehm.get_inference_errors(run_a) == 2 + assert self.ehm.get_inference_errors(run_b) == 1 + + def test_cleanup_removes_stale_entries(self): + self.ehm.add_inference_error(self.run_id) + assert self.ehm.get_inference_errors(self.run_id) == 1 + + # Simulate the entry going stale + self.ehm.error_hash_map[self.run_id].last_accessed_at = time.time() - ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 1 + self.ehm.last_cleanup_at = time.time() - ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 1 + + # Next access triggers cleanup, entry is gone + assert self.ehm.get_inference_errors(self.run_id) == 0 + + + +# --------------------------------------------------------------------------- +# Unit tests: non-halting error classification +# +# We define the expected set here to avoid importing inference_gateway.main, +# which triggers the config import chain and requires env vars. +# --------------------------------------------------------------------------- + +EXPECTED_NON_HALTING_ERROR_CODES = {500, 502, 503, 504, -1} + +class TestNonHaltingErrorClassification: + def test_server_errors_are_non_halting(self): + for code in [500, 502, 503, 504]: + assert code in EXPECTED_NON_HALTING_ERROR_CODES, f"Expected {code} to be non-halting" + + def test_internal_error_is_non_halting(self): + assert -1 in EXPECTED_NON_HALTING_ERROR_CODES + + def test_client_errors_are_halting(self): + for code in [400, 404, 422, 429]: + assert code not in EXPECTED_NON_HALTING_ERROR_CODES, f"Expected {code} to be halting" + + def test_success_is_not_non_halting(self): + assert 200 not in EXPECTED_NON_HALTING_ERROR_CODES + + + +# --------------------------------------------------------------------------- +# Unit tests: EvaluationRunErrorCode +# --------------------------------------------------------------------------- + +class TestEvaluationRunErrorCode: + def test_new_error_code_exists(self): + from models.evaluation_run import EvaluationRunErrorCode + code = EvaluationRunErrorCode.PLATFORM_TOO_MANY_INFERENCE_ERRORS + assert code.value == 3050 + assert code.is_platform_error() + assert not code.is_agent_error() + assert not code.is_validator_error() + + def test_error_message(self): + from models.evaluation_run import EvaluationRunErrorCode + msg = EvaluationRunErrorCode.PLATFORM_TOO_MANY_INFERENCE_ERRORS.get_error_message() + assert "inference errors" in msg.lower() + + + +# --------------------------------------------------------------------------- +# Integration tests: inference gateway endpoints +# +# These require the full app to be importable. They set minimal env vars +# to satisfy the config module, then mock the providers. +# --------------------------------------------------------------------------- + +def _set_minimal_gateway_env(): + """Set the bare minimum env vars so inference_gateway.config can load.""" + defaults = { + "HOST": "0.0.0.0", + "PORT": "9999", + "USE_DATABASE": "false", + "MAX_COST_PER_EVALUATION_RUN_USD": "10.0", + "MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN": "5", + "USE_CHUTES": "false", + "USE_TARGON": "false", + "USE_OPENROUTER": "false", + "TEST_INFERENCE_MODELS": "false", + "TEST_EMBEDDING_MODELS": "false", + } + for key, val in defaults.items(): + os.environ.setdefault(key, val) + + +# Guard: only define integration tests if we can import the app +_can_import_app = False +try: + _set_minimal_gateway_env() + + # The config fatals if no provider is enabled, so we need to patch + # the fatal check. We do this by temporarily setting one provider. + os.environ["USE_OPENROUTER"] = "true" + os.environ["OPENROUTER_BASE_URL"] = "http://localhost:9999" + os.environ["OPENROUTER_API_KEY"] = "test-key" + os.environ["OPENROUTER_WEIGHT"] = "1" + + from inference_gateway.main import app, cost_hash_map as global_cost_hash_map, error_hash_map as global_error_hash_map, is_non_halting_error, NON_HALTING_ERROR_CODES + from inference_gateway.models import InferenceResult + _can_import_app = True +except Exception: + pass + + +if _can_import_app: + import pytest_asyncio + from unittest.mock import AsyncMock, patch, MagicMock + from httpx import ASGITransport, AsyncClient + + @pytest_asyncio.fixture + async def client(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + def _mock_provider(status_code=200): + """Create a mock provider that returns the given status code.""" + provider = MagicMock() + provider.name = "MockProvider" + provider.is_model_supported_for_inference.return_value = True + + result = InferenceResult( + status_code=status_code, + content="hello" if status_code == 200 else None, + error_message="provider error" if status_code != 200 else None, + tool_calls=[], + num_input_tokens=10, + num_output_tokens=5, + cost_usd=0.001 + ) + provider.inference = AsyncMock(return_value=result) + return provider + + + @pytest.mark.asyncio + class TestInferenceGatewayErrorTracking: + def setup_method(self): + global_cost_hash_map.cost_hash_map = {} + global_cost_hash_map.last_cleanup_at = time.time() + global_error_hash_map.error_hash_map = {} + global_error_hash_map.last_cleanup_at = time.time() + + async def test_non_halting_error_increments_counter(self, client): + run_id = str(uuid4()) + mock_provider = _mock_provider(status_code=500) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_inference", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.post("/api/inference", json={ + "evaluation_run_id": run_id, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + assert response.status_code == 500 + from uuid import UUID + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 1 + + async def test_halting_error_does_not_increment_counter(self, client): + run_id = str(uuid4()) + mock_provider = _mock_provider(status_code=422) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_inference", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.post("/api/inference", json={ + "evaluation_run_id": run_id, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + assert response.status_code == 422 + from uuid import UUID + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 0 + + async def test_threshold_blocks_further_requests(self, client): + run_id = str(uuid4()) + from uuid import UUID + run_uuid = UUID(run_id) + + # Pre-fill the error count to the limit + for _ in range(5): + global_error_hash_map.add_inference_error(run_uuid) + + mock_provider = _mock_provider(status_code=200) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_inference", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = True + mock_config.CHECK_EVALUATION_RUNS = True + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + with patch("inference_gateway.main.get_evaluation_run_status_by_id", new_callable=AsyncMock) as mock_status: + from models.evaluation_run import EvaluationRunStatus + mock_status.return_value = EvaluationRunStatus.running_agent + + response = await client.post("/api/inference", json={ + "evaluation_run_id": run_id, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + assert response.status_code == 503 + assert "too many platform-side inference errors" in response.json()["detail"].lower() + + async def test_usage_endpoint_reports_errors(self, client): + run_id = str(uuid4()) + from uuid import UUID + run_uuid = UUID(run_id) + + global_error_hash_map.add_inference_error(run_uuid) + global_error_hash_map.add_inference_error(run_uuid) + global_cost_hash_map.add_cost(run_uuid, 0.05) + + with patch("inference_gateway.main.config") as mock_config: + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.get(f"/api/usage?evaluation_run_id={run_id}") + + assert response.status_code == 200 + data = response.json() + assert data["inference_errors"] == 2 + assert data["max_inference_errors"] == 5 + assert data["used_cost_usd"] == 0.05 + + async def test_constants_match_expected(self): + """Verify the actual constants match what we test against.""" + assert NON_HALTING_ERROR_CODES == EXPECTED_NON_HALTING_ERROR_CODES + assert is_non_halting_error(500) + assert not is_non_halting_error(400) diff --git a/validator/main.py b/validator/main.py index d5f84ab6a..29ee8039e 100644 --- a/validator/main.py +++ b/validator/main.py @@ -202,6 +202,28 @@ async def _run_evaluation_run(evaluation_run_id: UUID, problem_name: str, agent_ ) logger.info(f"Finished running agent for problem {problem_name}: {len(patch.splitlines())} lines of patch, {len(agent_logs.splitlines())} lines of agent logs") + # Check if the agent was affected by platform-side inference errors. + # If the inference gateway saw too many non-halting errors for this + # run, the agent never had a fair chance, so we bail out early and + # mark this as a platform error instead of scoring a bad patch. + try: + async with httpx.AsyncClient() as client: + usage_response = await client.get(f"{config.RIDGES_INFERENCE_GATEWAY_URL}/api/usage?evaluation_run_id={evaluation_run_id}") + if usage_response.status_code == 200: + usage = usage_response.json() + inference_errors = usage.get("inference_errors", 0) + max_inference_errors = usage.get("max_inference_errors", float("inf")) + if inference_errors >= max_inference_errors: + raise EvaluationRunException( + EvaluationRunErrorCode.PLATFORM_TOO_MANY_INFERENCE_ERRORS, + f"{EvaluationRunErrorCode.PLATFORM_TOO_MANY_INFERENCE_ERRORS.get_error_message()}: {inference_errors} inference errors (limit: {max_inference_errors})", + extra={"agent_logs": truncate_logs_if_required(agent_logs)} + ) + except EvaluationRunException: + raise + except Exception as e: + logger.warning(f"Failed to check inference error count for evaluation run {evaluation_run_id}: {e}") + # Move from running_agent -> initializing_eval await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.initializing_eval, { "patch": patch, @@ -243,7 +265,8 @@ async def _run_evaluation_run(evaluation_run_id: UUID, problem_name: str, agent_ await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.error, { "error_code": e.error_code.value, - "error_message": e.error_message + "error_message": e.error_message, + **(e.extra or {}) }) except Exception as e: From cee757ced84af1daeca40600075e18f897f28dea Mon Sep 17 00:00:00 2001 From: statxc <181730535+statxc@users.noreply.github.com> Date: Thu, 12 Mar 2026 14:11:29 +0000 Subject: [PATCH 2/3] refactor: update halting name to platform error --- inference_gateway/main.py | 10 ++++---- tests/test_inference_error_tracking.py | 34 +++++++++++++------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/inference_gateway/main.py b/inference_gateway/main.py index cf737387c..813b8486b 100644 --- a/inference_gateway/main.py +++ b/inference_gateway/main.py @@ -116,10 +116,10 @@ async def wrapper(*args, **kwargs): # 400/404/422 are excluded because those indicate bad agent requests (wrong # model, invalid format, etc). 429 is excluded because it means the cost # limit was intentionally reached. -NON_HALTING_ERROR_CODES = {500, 502, 503, 504, -1} +PLATFORM_ERROR_CODES = {500, 502, 503, 504, -1} -def is_non_halting_error(status_code: int) -> bool: - return status_code in NON_HALTING_ERROR_CODES +def is_platform_error(status_code: int) -> bool: + return status_code in PLATFORM_ERROR_CODES @@ -228,7 +228,7 @@ async def inference(request: InferenceRequest) -> InferenceResponse: ) else: # Track non-halting errors (platform/provider failures, not agent mistakes) - if is_non_halting_error(response.status_code): + if is_platform_error(response.status_code): error_hash_map.add_inference_error(request.evaluation_run_id) raise HTTPException( @@ -321,7 +321,7 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse: ) else: # Track non-halting errors (platform/provider failures, not agent mistakes) - if is_non_halting_error(response.status_code): + if is_platform_error(response.status_code): error_hash_map.add_inference_error(request.evaluation_run_id) raise HTTPException( diff --git a/tests/test_inference_error_tracking.py b/tests/test_inference_error_tracking.py index e1affae85..5c4eddcbc 100644 --- a/tests/test_inference_error_tracking.py +++ b/tests/test_inference_error_tracking.py @@ -62,28 +62,28 @@ def test_cleanup_removes_stale_entries(self): # --------------------------------------------------------------------------- -# Unit tests: non-halting error classification +# Unit tests: platform error classification # # We define the expected set here to avoid importing inference_gateway.main, # which triggers the config import chain and requires env vars. # --------------------------------------------------------------------------- -EXPECTED_NON_HALTING_ERROR_CODES = {500, 502, 503, 504, -1} +EXPECTED_PLATFORM_ERROR_CODES = {500, 502, 503, 504, -1} -class TestNonHaltingErrorClassification: - def test_server_errors_are_non_halting(self): +class TestPlatformErrorClassification: + def test_server_errors_are_platform_errors(self): for code in [500, 502, 503, 504]: - assert code in EXPECTED_NON_HALTING_ERROR_CODES, f"Expected {code} to be non-halting" + assert code in EXPECTED_PLATFORM_ERROR_CODES, f"Expected {code} to be a platform error" - def test_internal_error_is_non_halting(self): - assert -1 in EXPECTED_NON_HALTING_ERROR_CODES + def test_internal_error_is_platform_error(self): + assert -1 in EXPECTED_PLATFORM_ERROR_CODES - def test_client_errors_are_halting(self): + def test_client_errors_are_not_platform_errors(self): for code in [400, 404, 422, 429]: - assert code not in EXPECTED_NON_HALTING_ERROR_CODES, f"Expected {code} to be halting" + assert code not in EXPECTED_PLATFORM_ERROR_CODES, f"Expected {code} to not be a platform error" - def test_success_is_not_non_halting(self): - assert 200 not in EXPECTED_NON_HALTING_ERROR_CODES + def test_success_is_not_platform_error(self): + assert 200 not in EXPECTED_PLATFORM_ERROR_CODES @@ -144,7 +144,7 @@ def _set_minimal_gateway_env(): os.environ["OPENROUTER_API_KEY"] = "test-key" os.environ["OPENROUTER_WEIGHT"] = "1" - from inference_gateway.main import app, cost_hash_map as global_cost_hash_map, error_hash_map as global_error_hash_map, is_non_halting_error, NON_HALTING_ERROR_CODES + from inference_gateway.main import app, cost_hash_map as global_cost_hash_map, error_hash_map as global_error_hash_map, is_platform_error, PLATFORM_ERROR_CODES from inference_gateway.models import InferenceResult _can_import_app = True except Exception: @@ -189,7 +189,7 @@ def setup_method(self): global_error_hash_map.error_hash_map = {} global_error_hash_map.last_cleanup_at = time.time() - async def test_non_halting_error_increments_counter(self, client): + async def test_platform_error_increments_counter(self, client): run_id = str(uuid4()) mock_provider = _mock_provider(status_code=500) @@ -210,7 +210,7 @@ async def test_non_halting_error_increments_counter(self, client): from uuid import UUID assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 1 - async def test_halting_error_does_not_increment_counter(self, client): + async def test_non_platform_error_does_not_increment_counter(self, client): run_id = str(uuid4()) mock_provider = _mock_provider(status_code=422) @@ -286,6 +286,6 @@ async def test_usage_endpoint_reports_errors(self, client): async def test_constants_match_expected(self): """Verify the actual constants match what we test against.""" - assert NON_HALTING_ERROR_CODES == EXPECTED_NON_HALTING_ERROR_CODES - assert is_non_halting_error(500) - assert not is_non_halting_error(400) + assert PLATFORM_ERROR_CODES == EXPECTED_PLATFORM_ERROR_CODES + assert is_platform_error(500) + assert not is_platform_error(400) From 57d58897b5912d1c6e2243379a2fa02549cf6f5f Mon Sep 17 00:00:00 2001 From: statxc <181730535+statxc@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:02:57 +0000 Subject: [PATCH 3/3] refactor: Rename to platform error, add logging and httpx timeout, add embedding and edge case tests --- inference_gateway/error_hash_map.py | 6 +- inference_gateway/main.py | 10 ++- tests/test_inference_error_tracking.py | 88 ++++++++++++++++++++++++-- validator/main.py | 2 +- 4 files changed, 96 insertions(+), 10 deletions(-) diff --git a/inference_gateway/error_hash_map.py b/inference_gateway/error_hash_map.py index d56a3fcf5..06078f3ec 100644 --- a/inference_gateway/error_hash_map.py +++ b/inference_gateway/error_hash_map.py @@ -1,6 +1,6 @@ -# Tracks the number of non-halting (platform-side) inference errors per -# evaluation run. When the count exceeds a configured threshold the run -# is flagged as a platform error so the agent is not penalized unfairly. +# Tracks the number of platform-side inference errors per evaluation run. +# When the count exceeds a configured threshold the run is flagged as a +# platform error so the agent is not penalized unfairly. import time diff --git a/inference_gateway/main.py b/inference_gateway/main.py index 813b8486b..e17c359aa 100644 --- a/inference_gateway/main.py +++ b/inference_gateway/main.py @@ -174,6 +174,7 @@ async def inference(request: InferenceRequest) -> InferenceResponse: # Make sure the evaluation run has not had too many platform-side inference errors inference_errors = error_hash_map.get_inference_errors(request.evaluation_run_id) if inference_errors >= config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN: + logger.warning(f"Blocking inference for run {request.evaluation_run_id}: too many platform errors ({inference_errors}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})") raise HTTPException( status_code=503, detail=f"The evaluation run with ID {request.evaluation_run_id} has had too many platform-side inference errors ({inference_errors} errors, limit is {config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})." @@ -227,9 +228,11 @@ async def inference(request: InferenceRequest) -> InferenceResponse: tool_calls=response.tool_calls ) else: - # Track non-halting errors (platform/provider failures, not agent mistakes) + # Track platform errors (provider failures, not agent mistakes) if is_platform_error(response.status_code): error_hash_map.add_inference_error(request.evaluation_run_id) + error_count = error_hash_map.get_inference_errors(request.evaluation_run_id) + logger.warning(f"Platform inference error for run {request.evaluation_run_id}: status {response.status_code} (error {error_count}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})") raise HTTPException( status_code=response.status_code, @@ -273,6 +276,7 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse: # Make sure the evaluation run has not had too many platform-side inference errors inference_errors = error_hash_map.get_inference_errors(request.evaluation_run_id) if inference_errors >= config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN: + logger.warning(f"Blocking embedding for run {request.evaluation_run_id}: too many platform errors ({inference_errors}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})") raise HTTPException( status_code=503, detail=f"The evaluation run with ID {request.evaluation_run_id} has had too many platform-side inference errors ({inference_errors} errors, limit is {config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})." @@ -320,9 +324,11 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse: embedding=response.embedding ) else: - # Track non-halting errors (platform/provider failures, not agent mistakes) + # Track platform errors (provider failures, not agent mistakes) if is_platform_error(response.status_code): error_hash_map.add_inference_error(request.evaluation_run_id) + error_count = error_hash_map.get_inference_errors(request.evaluation_run_id) + logger.warning(f"Platform embedding error for run {request.evaluation_run_id}: status {response.status_code} (error {error_count}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})") raise HTTPException( status_code=response.status_code, diff --git a/tests/test_inference_error_tracking.py b/tests/test_inference_error_tracking.py index 5c4eddcbc..f09766e05 100644 --- a/tests/test_inference_error_tracking.py +++ b/tests/test_inference_error_tracking.py @@ -145,7 +145,7 @@ def _set_minimal_gateway_env(): os.environ["OPENROUTER_WEIGHT"] = "1" from inference_gateway.main import app, cost_hash_map as global_cost_hash_map, error_hash_map as global_error_hash_map, is_platform_error, PLATFORM_ERROR_CODES - from inference_gateway.models import InferenceResult + from inference_gateway.models import InferenceResult, EmbeddingResult _can_import_app = True except Exception: pass @@ -162,13 +162,14 @@ async def client(): async with AsyncClient(transport=transport, base_url="http://test") as c: yield c - def _mock_provider(status_code=200): + def _mock_provider(status_code=200, for_embedding=False): """Create a mock provider that returns the given status code.""" provider = MagicMock() provider.name = "MockProvider" provider.is_model_supported_for_inference.return_value = True + provider.is_model_supported_for_embedding.return_value = True - result = InferenceResult( + inference_result = InferenceResult( status_code=status_code, content="hello" if status_code == 200 else None, error_message="provider error" if status_code != 200 else None, @@ -177,7 +178,17 @@ def _mock_provider(status_code=200): num_output_tokens=5, cost_usd=0.001 ) - provider.inference = AsyncMock(return_value=result) + provider.inference = AsyncMock(return_value=inference_result) + + embedding_result = EmbeddingResult( + status_code=status_code, + embedding=[0.1, 0.2, 0.3] if status_code == 200 else None, + error_message="provider error" if status_code != 200 else None, + num_input_tokens=10, + cost_usd=0.0005 + ) + provider.embedding = AsyncMock(return_value=embedding_result) + return provider @@ -284,6 +295,75 @@ async def test_usage_endpoint_reports_errors(self, client): assert data["max_inference_errors"] == 5 assert data["used_cost_usd"] == 0.05 + async def test_usage_endpoint_zero_errors(self, client): + """A run with no errors should report zero.""" + run_id = str(uuid4()) + + with patch("inference_gateway.main.config") as mock_config: + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.get(f"/api/usage?evaluation_run_id={run_id}") + + assert response.status_code == 200 + data = response.json() + assert data["inference_errors"] == 0 + assert data["used_cost_usd"] == 0.0 + + async def test_separate_runs_do_not_interfere(self, client): + """Errors for one run must not affect another run's count.""" + run_a = str(uuid4()) + run_b = str(uuid4()) + mock_provider = _mock_provider(status_code=502) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_inference", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + # Hit run_a three times + for _ in range(3): + await client.post("/api/inference", json={ + "evaluation_run_id": run_a, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + # Hit run_b once + await client.post("/api/inference", json={ + "evaluation_run_id": run_b, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + from uuid import UUID + assert global_error_hash_map.get_inference_errors(UUID(run_a)) == 3 + assert global_error_hash_map.get_inference_errors(UUID(run_b)) == 1 + + async def test_embedding_platform_error_increments_counter(self, client): + """Embedding endpoint should also track platform errors.""" + run_id = str(uuid4()) + mock_provider = _mock_provider(status_code=503) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_embedding", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.post("/api/embedding", json={ + "evaluation_run_id": run_id, + "model": "test-model", + "input": "hello world" + }) + + assert response.status_code == 503 + from uuid import UUID + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 1 + async def test_constants_match_expected(self): """Verify the actual constants match what we test against.""" assert PLATFORM_ERROR_CODES == EXPECTED_PLATFORM_ERROR_CODES diff --git a/validator/main.py b/validator/main.py index 29ee8039e..45530c649 100644 --- a/validator/main.py +++ b/validator/main.py @@ -207,7 +207,7 @@ async def _run_evaluation_run(evaluation_run_id: UUID, problem_name: str, agent_ # run, the agent never had a fair chance, so we bail out early and # mark this as a platform error instead of scoring a bad patch. try: - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(timeout=10.0) as client: usage_response = await client.get(f"{config.RIDGES_INFERENCE_GATEWAY_URL}/api/usage?evaluation_run_id={evaluation_run_id}") if usage_response.status_code == 200: usage = usage_response.json()