# Number of platform-side inference errors an evaluation run may accumulate
# before the gateway starts rejecting its requests. Optional: falls back to 5
# (with a warning) when the variable is unset or empty.
MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = os.getenv("MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN")
if not MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN:
    MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5
    logger.warning(f"MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN is not set in .env, defaulting to {MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN}")
try:
    MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = int(MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN)
except ValueError as e:
    # Same exception type int() raised before, but naming the offending setting.
    raise ValueError(
        f"MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN must be an integer, got {MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN!r}"
    ) from e
if MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN < 1:
    # The gateway blocks a run once `errors >= limit`; a limit of 0 or less
    # would therefore reject every request of every run.
    raise ValueError(
        f"MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN must be at least 1, got {MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN}"
    )
# Default eviction window: an entry that has not been read or written for this
# long is dropped, and eviction sweeps run at most once per window.
ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS = 60 # 1 minute


class ErrorHashMap:
    """In-memory count of platform-side inference errors per evaluation run.

    Entries are keyed by evaluation run UUID and hold an ``ErrorHashMapEntry``
    (defined alongside this class). Every read or write refreshes the entry's
    ``last_accessed_at``; entries idle for longer than the cleanup interval
    are evicted lazily on the next access.

    NOTE(review): because eviction is idle-based, a run that makes no gateway
    call for longer than the interval loses its count. Confirm the default of
    60 seconds is long enough for real runs, or pass a larger
    ``cleanup_interval_seconds``.
    """

    def __init__(self, cleanup_interval_seconds: "float | None" = None):
        """Create an empty map.

        Args:
            cleanup_interval_seconds: Idle time after which an entry is
                evicted. ``None`` (the default) uses the module-level
                ``ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS``.
        """
        self.error_hash_map = {}
        self.last_cleanup_at = time.time()
        self.cleanup_interval_seconds = cleanup_interval_seconds

    def _cleanup(self):
        """Drop idle entries, at most once per cleanup interval."""
        interval = self.cleanup_interval_seconds
        if interval is None:
            interval = ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS

        now = time.time()
        if now - self.last_cleanup_at > interval:
            self.error_hash_map = {
                run_id: entry
                for run_id, entry in self.error_hash_map.items()
                if now - entry.last_accessed_at < interval
            }
            self.last_cleanup_at = now

    def get_inference_errors(self, uuid: UUID) -> int:
        """Return the error count for ``uuid`` (0 if the run is unknown)."""
        self._cleanup()

        entry = self.error_hash_map.get(uuid)
        if entry is None:
            return 0
        # Reading counts as activity, so polling keeps the entry alive.
        entry.last_accessed_at = time.time()
        return entry.inference_errors

    def add_inference_error(self, uuid: UUID) -> int:
        """Record one more error for ``uuid`` and return the updated count.

        Returning the count lets callers log or compare it without a second
        lookup; callers that ignore the return value are unaffected.
        """
        self._cleanup()

        now = time.time()
        entry = self.error_hash_map.get(uuid)
        if entry is None:
            self.error_hash_map[uuid] = ErrorHashMapEntry(inference_errors=1, last_accessed_at=now)
            return 1

        entry.inference_errors += 1
        entry.last_accessed_at = now
        return entry.inference_errors

    def reset_inference_errors(self, uuid: UUID) -> None:
        """Reset the error count for a specific evaluation run (used before retrying).

        Unknown runs are ignored.
        """
        self.error_hash_map.pop(uuid, None)
# Platform-side error codes that are not the agent's fault. 4xx errors like
# 400/404/422 are excluded because those indicate bad agent requests (wrong
# model, invalid format, etc). 429 is excluded because it means the cost
# limit was intentionally reached.
# NOTE(review): -1 looks like the gateway's own internal-error sentinel
# (the tests call it the "internal error") - confirm against the providers.
PLATFORM_ERROR_CODES = {500, 502, 503, 504, -1}

def is_platform_error(status_code: int) -> bool:
    """Return True when ``status_code`` denotes a failure on the platform side.

    A code counts as a platform error if it is listed in
    ``PLATFORM_ERROR_CODES`` or is any other 5xx status: server-side failures
    such as 501, 52x or 529 are no more the agent's fault than the four
    listed codes. 2xx and 4xx codes (including 429) never count.
    """
    return status_code in PLATFORM_ERROR_CODES or 500 <= status_code <= 599
+ ) + + # Make sure we support the model for inference provider = get_provider_that_supports_model_for_inference(request.model) @@ -206,6 +228,12 @@ async def inference(request: InferenceRequest) -> InferenceResponse: tool_calls=response.tool_calls ) else: + # Track platform errors (provider failures, not agent mistakes) + if is_platform_error(response.status_code): + error_hash_map.add_inference_error(request.evaluation_run_id) + error_count = error_hash_map.get_inference_errors(request.evaluation_run_id) + logger.warning(f"Platform inference error for run {request.evaluation_run_id}: status {response.status_code} (error {error_count}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})") + raise HTTPException( status_code=response.status_code, detail=response.error_message @@ -245,7 +273,16 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse: detail=f"The evaluation run with ID {request.evaluation_run_id} has reached or exceeded the evaluation run cost limit of {config.MAX_COST_PER_EVALUATION_RUN_USD} USD (current cost: {cost} USD)." ) - + # Make sure the evaluation run has not had too many platform-side inference errors + inference_errors = error_hash_map.get_inference_errors(request.evaluation_run_id) + if inference_errors >= config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN: + logger.warning(f"Blocking embedding for run {request.evaluation_run_id}: too many platform errors ({inference_errors}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})") + raise HTTPException( + status_code=503, + detail=f"The evaluation run with ID {request.evaluation_run_id} has had too many platform-side inference errors ({inference_errors} errors, limit is {config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})." 
+ ) + + # Make sure we support the model for embedding provider = get_provider_that_supports_model_for_embedding(request.model) @@ -287,6 +324,12 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse: embedding=response.embedding ) else: + # Track platform errors (provider failures, not agent mistakes) + if is_platform_error(response.status_code): + error_hash_map.add_inference_error(request.evaluation_run_id) + error_count = error_hash_map.get_inference_errors(request.evaluation_run_id) + logger.warning(f"Platform embedding error for run {request.evaluation_run_id}: status {response.status_code} (error {error_count}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})") + raise HTTPException( status_code=response.status_code, detail=response.error_message @@ -298,6 +341,8 @@ class UsageResponse(BaseModel): used_cost_usd: float remaining_cost_usd: float max_cost_usd: float + inference_errors: int + max_inference_errors: int @app.get("/api/usage") @handle_http_exceptions @@ -306,9 +351,18 @@ async def usage(evaluation_run_id: UUID) -> UsageResponse: return UsageResponse( used_cost_usd=used_cost_usd, remaining_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD - used_cost_usd, - max_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD + max_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD, + inference_errors=error_hash_map.get_inference_errors(evaluation_run_id), + max_inference_errors=config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN ) +@app.post("/api/reset-inference-errors") +@handle_http_exceptions +async def reset_inference_errors(evaluation_run_id: UUID): + """Reset the inference error count for a specific evaluation run before retrying.""" + error_hash_map.reset_inference_errors(evaluation_run_id) + return {"status": "ok"} + @app.get("/api/inference-models") diff --git a/models/evaluation_run.py b/models/evaluation_run.py index 891c28f7c..5c70a641c 100644 --- a/models/evaluation_run.py +++ b/models/evaluation_run.py @@ -37,6 +37,7 @@ def __new__(cls, code: 
int, message: str): PLATFORM_RESTARTED_WHILE_RUNNING_AGENT = (3020, "The platform was restarted while the evaluation run was running the agent") PLATFORM_RESTARTED_WHILE_INIT_EVAL = (3030, "The platform was restarted while the evaluation run was initializing the evaluation") PLATFORM_RESTARTED_WHILE_RUNNING_EVAL = (3040, "The platform was restarted while the evaluation run was running the evaluation") + PLATFORM_TOO_MANY_INFERENCE_ERRORS = (3050, "Too many platform-side inference errors occurred during the evaluation run") def get_error_message(self) -> str: return self.message diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_inference_error_tracking.py b/tests/test_inference_error_tracking.py new file mode 100644 index 000000000..e1f9200ae --- /dev/null +++ b/tests/test_inference_error_tracking.py @@ -0,0 +1,551 @@ +""" +Tests for platform-side inference error tracking. +Covers: + - ErrorHashMap error counting and cleanup + - Non-halting error code classification + - Inference gateway threshold enforcement via the /api/inference endpoint + - Usage endpoint reporting error counts +""" + +import os +import time +import pytest +from uuid import uuid4 + +from inference_gateway.error_hash_map import ErrorHashMap, ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS + + + +# --------------------------------------------------------------------------- +# Unit tests: ErrorHashMap +# --------------------------------------------------------------------------- + +class TestErrorHashMap: + def setup_method(self): + self.ehm = ErrorHashMap() + self.run_id = uuid4() + + def test_get_inference_errors_returns_zero_for_unknown_run(self): + assert self.ehm.get_inference_errors(uuid4()) == 0 + + def test_add_inference_error_creates_entry(self): + self.ehm.add_inference_error(self.run_id) + assert self.ehm.get_inference_errors(self.run_id) == 1 + + def test_add_inference_error_increments(self): + for _ in range(7): + 
self.ehm.add_inference_error(self.run_id) + assert self.ehm.get_inference_errors(self.run_id) == 7 + + def test_separate_runs_tracked_independently(self): + run_a = uuid4() + run_b = uuid4() + + self.ehm.add_inference_error(run_a) + self.ehm.add_inference_error(run_a) + self.ehm.add_inference_error(run_b) + + assert self.ehm.get_inference_errors(run_a) == 2 + assert self.ehm.get_inference_errors(run_b) == 1 + + def test_cleanup_removes_stale_entries(self): + self.ehm.add_inference_error(self.run_id) + assert self.ehm.get_inference_errors(self.run_id) == 1 + + # Simulate the entry going stale + self.ehm.error_hash_map[self.run_id].last_accessed_at = time.time() - ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 1 + self.ehm.last_cleanup_at = time.time() - ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS - 1 + + # Next access triggers cleanup, entry is gone + assert self.ehm.get_inference_errors(self.run_id) == 0 + + def test_reset_inference_errors_clears_count(self): + for _ in range(5): + self.ehm.add_inference_error(self.run_id) + assert self.ehm.get_inference_errors(self.run_id) == 5 + + self.ehm.reset_inference_errors(self.run_id) + assert self.ehm.get_inference_errors(self.run_id) == 0 + + def test_reset_inference_errors_noop_for_unknown_run(self): + # Should not raise + self.ehm.reset_inference_errors(uuid4()) + + def test_reset_does_not_affect_other_runs(self): + run_a = uuid4() + run_b = uuid4() + self.ehm.add_inference_error(run_a) + self.ehm.add_inference_error(run_a) + self.ehm.add_inference_error(run_b) + + self.ehm.reset_inference_errors(run_a) + assert self.ehm.get_inference_errors(run_a) == 0 + assert self.ehm.get_inference_errors(run_b) == 1 + + + +# --------------------------------------------------------------------------- +# Unit tests: platform error classification +# +# We define the expected set here to avoid importing inference_gateway.main, +# which triggers the config import chain and requires env vars. 
+# --------------------------------------------------------------------------- + +EXPECTED_PLATFORM_ERROR_CODES = {500, 502, 503, 504, -1} + +class TestPlatformErrorClassification: + def test_server_errors_are_platform_errors(self): + for code in [500, 502, 503, 504]: + assert code in EXPECTED_PLATFORM_ERROR_CODES, f"Expected {code} to be a platform error" + + def test_internal_error_is_platform_error(self): + assert -1 in EXPECTED_PLATFORM_ERROR_CODES + + def test_client_errors_are_not_platform_errors(self): + for code in [400, 404, 422, 429]: + assert code not in EXPECTED_PLATFORM_ERROR_CODES, f"Expected {code} to not be a platform error" + + def test_success_is_not_platform_error(self): + assert 200 not in EXPECTED_PLATFORM_ERROR_CODES + + + +# --------------------------------------------------------------------------- +# Unit tests: EvaluationRunErrorCode +# --------------------------------------------------------------------------- + +class TestEvaluationRunErrorCode: + def test_new_error_code_exists(self): + from models.evaluation_run import EvaluationRunErrorCode + code = EvaluationRunErrorCode.PLATFORM_TOO_MANY_INFERENCE_ERRORS + assert code.value == 3050 + assert code.is_platform_error() + assert not code.is_agent_error() + assert not code.is_validator_error() + + def test_error_message(self): + from models.evaluation_run import EvaluationRunErrorCode + msg = EvaluationRunErrorCode.PLATFORM_TOO_MANY_INFERENCE_ERRORS.get_error_message() + assert "inference errors" in msg.lower() + + + +# --------------------------------------------------------------------------- +# Integration tests: inference gateway endpoints +# +# These require the full app to be importable. They set minimal env vars +# to satisfy the config module, then mock the providers. 
+# --------------------------------------------------------------------------- + +def _set_minimal_gateway_env(): + """Set the bare minimum env vars so inference_gateway.config can load.""" + defaults = { + "HOST": "0.0.0.0", + "PORT": "9999", + "USE_DATABASE": "false", + "MAX_COST_PER_EVALUATION_RUN_USD": "10.0", + "MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN": "5", + "USE_CHUTES": "false", + "USE_TARGON": "false", + "USE_OPENROUTER": "false", + "TEST_INFERENCE_MODELS": "false", + "TEST_EMBEDDING_MODELS": "false", + } + for key, val in defaults.items(): + os.environ.setdefault(key, val) + + +# Guard: only define integration tests if we can import the app +_can_import_app = False +try: + _set_minimal_gateway_env() + + # The config fatals if no provider is enabled, so we need to patch + # the fatal check. We do this by temporarily setting one provider. + os.environ["USE_OPENROUTER"] = "true" + os.environ["OPENROUTER_BASE_URL"] = "http://localhost:9999" + os.environ["OPENROUTER_API_KEY"] = "test-key" + os.environ["OPENROUTER_WEIGHT"] = "1" + + from inference_gateway.main import app, cost_hash_map as global_cost_hash_map, error_hash_map as global_error_hash_map, is_platform_error, PLATFORM_ERROR_CODES + from inference_gateway.models import InferenceResult, EmbeddingResult + _can_import_app = True +except Exception: + pass + + +if _can_import_app: + import pytest_asyncio + from unittest.mock import AsyncMock, patch, MagicMock + from httpx import ASGITransport, AsyncClient + + @pytest_asyncio.fixture + async def client(): + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + def _mock_provider(status_code=200, for_embedding=False): + """Create a mock provider that returns the given status code.""" + provider = MagicMock() + provider.name = "MockProvider" + provider.is_model_supported_for_inference.return_value = True + provider.is_model_supported_for_embedding.return_value = True + + inference_result = 
InferenceResult( + status_code=status_code, + content="hello" if status_code == 200 else None, + error_message="provider error" if status_code != 200 else None, + tool_calls=[], + num_input_tokens=10, + num_output_tokens=5, + cost_usd=0.001 + ) + provider.inference = AsyncMock(return_value=inference_result) + + embedding_result = EmbeddingResult( + status_code=status_code, + embedding=[0.1, 0.2, 0.3] if status_code == 200 else None, + error_message="provider error" if status_code != 200 else None, + num_input_tokens=10, + cost_usd=0.0005 + ) + provider.embedding = AsyncMock(return_value=embedding_result) + + return provider + + + @pytest.mark.asyncio + class TestInferenceGatewayErrorTracking: + def setup_method(self): + global_cost_hash_map.cost_hash_map = {} + global_cost_hash_map.last_cleanup_at = time.time() + global_error_hash_map.error_hash_map = {} + global_error_hash_map.last_cleanup_at = time.time() + + async def test_platform_error_increments_counter(self, client): + run_id = str(uuid4()) + mock_provider = _mock_provider(status_code=500) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_inference", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.post("/api/inference", json={ + "evaluation_run_id": run_id, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + assert response.status_code == 500 + from uuid import UUID + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 1 + + async def test_non_platform_error_does_not_increment_counter(self, client): + run_id = str(uuid4()) + mock_provider = _mock_provider(status_code=422) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_inference", return_value=mock_provider), \ + 
patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.post("/api/inference", json={ + "evaluation_run_id": run_id, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + assert response.status_code == 422 + from uuid import UUID + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 0 + + async def test_threshold_blocks_further_requests(self, client): + run_id = str(uuid4()) + from uuid import UUID + run_uuid = UUID(run_id) + + # Pre-fill the error count to the limit + for _ in range(5): + global_error_hash_map.add_inference_error(run_uuid) + + mock_provider = _mock_provider(status_code=200) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_inference", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = True + mock_config.CHECK_EVALUATION_RUNS = True + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + with patch("inference_gateway.main.get_evaluation_run_status_by_id", new_callable=AsyncMock) as mock_status: + from models.evaluation_run import EvaluationRunStatus + mock_status.return_value = EvaluationRunStatus.running_agent + + response = await client.post("/api/inference", json={ + "evaluation_run_id": run_id, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + assert response.status_code == 503 + assert "too many platform-side inference errors" in response.json()["detail"].lower() + + async def test_usage_endpoint_reports_errors(self, client): + run_id = str(uuid4()) + from uuid import UUID + run_uuid = UUID(run_id) + + global_error_hash_map.add_inference_error(run_uuid) + 
global_error_hash_map.add_inference_error(run_uuid) + global_cost_hash_map.add_cost(run_uuid, 0.05) + + with patch("inference_gateway.main.config") as mock_config: + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.get(f"/api/usage?evaluation_run_id={run_id}") + + assert response.status_code == 200 + data = response.json() + assert data["inference_errors"] == 2 + assert data["max_inference_errors"] == 5 + assert data["used_cost_usd"] == 0.05 + + async def test_usage_endpoint_zero_errors(self, client): + """A run with no errors should report zero.""" + run_id = str(uuid4()) + + with patch("inference_gateway.main.config") as mock_config: + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.get(f"/api/usage?evaluation_run_id={run_id}") + + assert response.status_code == 200 + data = response.json() + assert data["inference_errors"] == 0 + assert data["used_cost_usd"] == 0.0 + + async def test_separate_runs_do_not_interfere(self, client): + """Errors for one run must not affect another run's count.""" + run_a = str(uuid4()) + run_b = str(uuid4()) + mock_provider = _mock_provider(status_code=502) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_inference", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + # Hit run_a three times + for _ in range(3): + await client.post("/api/inference", json={ + "evaluation_run_id": run_a, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": "user", "content": "test"}] + }) + + # Hit run_b once + await client.post("/api/inference", json={ + "evaluation_run_id": run_b, + "model": "test-model", + "temperature": 0.5, + "messages": [{"role": 
"user", "content": "test"}] + }) + + from uuid import UUID + assert global_error_hash_map.get_inference_errors(UUID(run_a)) == 3 + assert global_error_hash_map.get_inference_errors(UUID(run_b)) == 1 + + async def test_embedding_platform_error_increments_counter(self, client): + """Embedding endpoint should also track platform errors.""" + run_id = str(uuid4()) + mock_provider = _mock_provider(status_code=503) + + with patch("inference_gateway.main.get_provider_that_supports_model_for_embedding", return_value=mock_provider), \ + patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5 + + response = await client.post("/api/embedding", json={ + "evaluation_run_id": run_id, + "model": "test-model", + "input": "hello world" + }) + + assert response.status_code == 503 + from uuid import UUID + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 1 + + async def test_constants_match_expected(self): + """Verify the actual constants match what we test against.""" + assert PLATFORM_ERROR_CODES == EXPECTED_PLATFORM_ERROR_CODES + assert is_platform_error(500) + assert not is_platform_error(400) + + async def test_reset_endpoint_clears_errors(self, client): + """POST /api/reset-inference-errors should clear the error count.""" + run_id = str(uuid4()) + + # Add some errors + from uuid import UUID + global_error_hash_map.add_inference_error(UUID(run_id)) + global_error_hash_map.add_inference_error(UUID(run_id)) + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 2 + + # Reset via endpoint + response = await client.post(f"/api/reset-inference-errors?evaluation_run_id={run_id}") + assert response.status_code == 200 + + # Errors should be cleared + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 0 + + async def test_reset_endpoint_allows_new_inferences_after_threshold(self, client): + 
"""After resetting errors, the error count is 0 so the threshold check passes.""" + run_id = str(uuid4()) + from uuid import UUID + + with patch("inference_gateway.main.config") as mock_config: + mock_config.USE_DATABASE = False + mock_config.MAX_COST_PER_EVALUATION_RUN_USD = 10.0 + mock_config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 2 + + # Fill up the error counter to the threshold + global_error_hash_map.add_inference_error(UUID(run_id)) + global_error_hash_map.add_inference_error(UUID(run_id)) + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 2 + + # Reset errors via endpoint + reset_response = await client.post(f"/api/reset-inference-errors?evaluation_run_id={run_id}") + assert reset_response.status_code == 200 + assert global_error_hash_map.get_inference_errors(UUID(run_id)) == 0 + + # Verify usage endpoint also reflects the reset + usage_response = await client.get(f"/api/usage?evaluation_run_id={run_id}") + assert usage_response.status_code == 200 + assert usage_response.json()["inference_errors"] == 0 + + + +# --------------------------------------------------------------------------- +# Integration test: validator retry logic +# +# This test mocks the sandbox/problem suite to verify that the validator +# retries a single run when inference errors exceed the threshold, rather +# than failing the entire evaluation. +# +# Requires: the full validator environment (sandbox_manager, problem_suites) +# Run with: python3 -m pytest tests/test_inference_error_tracking.py::TestValidatorRetryLogic -v +# --------------------------------------------------------------------------- + +@pytest.mark.skipif( + not os.getenv("NETUID"), + reason="Requires full validator environment (.env with NETUID, etc.)" +) +class TestValidatorRetryLogic: + """Tests for the validator's per-run retry behavior on platform inference errors. + + These tests require the full validator environment to be configured. + Run with: NETUID=... 
@pytest.fixture
def mock_httpx_responses(self):
    """Build a stand-in for ``httpx.AsyncClient`` plus its call counters.

    Returns:
        ``(MockClient, call_count)`` where ``call_count["check"]`` counts GET
        calls (usage checks) and ``call_count["reset"]`` counts POST calls
        (error-counter resets).
    """
    call_count = {"check": 0, "reset": 0}

    class MockResponse:
        def __init__(self, status_code, json_data=None):
            self.status_code = status_code
            self._json = json_data or {}

        def json(self):
            return self._json

    class MockClient:
        def __init__(self, *args, **kwargs):
            # The validator constructs the client as
            # httpx.AsyncClient(timeout=10.0). Without a constructor that
            # accepts keyword arguments the patched call raises TypeError,
            # which the validator's broad `except` swallows, so the retry
            # path under test would never run.
            pass

        async def __aenter__(self):
            return self

        async def __aexit__(self, *args):
            pass

        async def get(self, url, **kwargs):
            call_count["check"] += 1
            # First call: errors exceed threshold (triggers retry)
            # Second call: errors below threshold (retry succeeded)
            if call_count["check"] == 1:
                return MockResponse(200, {"inference_errors": 5, "max_inference_errors": 5})
            return MockResponse(200, {"inference_errors": 0, "max_inference_errors": 5})

        async def post(self, url, **kwargs):
            call_count["reset"] += 1
            return MockResponse(200)

    return MockClient, call_count
ProblemTestResultStatus.PASS + mock_test_result.model_dump.return_value = {"status": "pass", "name": "test"} + mock_suite.run_eval_sandbox.return_value = ([mock_test_result], "eval logs") + + # Track how many times the agent sandbox is initialized (= number of attempts) + original_init = mock_suite.initialize_agent_sandbox + def counting_init(*args, **kwargs): + run_count["value"] += 1 + return original_init(*args, **kwargs) + mock_suite.initialize_agent_sandbox.side_effect = counting_init + + # Import and patch + import validator.main as val_main + + updates = [] + async def mock_update(eid, pname, status, extra=None): + updates.append((status, extra)) + + with patch.object(val_main, "problem_suites", [mock_suite]), \ + patch.object(val_main, "sandbox_manager", MagicMock()), \ + patch.object(val_main, "running_agent_timeout_seconds", 60), \ + patch.object(val_main, "running_eval_timeout_seconds", 60), \ + patch.object(val_main, "update_evaluation_run", mock_update), \ + patch.object(val_main, "truncate_logs_if_required", lambda x: x), \ + patch("httpx.AsyncClient", MockClient): + + await val_main._run_evaluation_run(run_id, problem_name, agent_code) + + # Should have been called twice (first attempt + one retry) + assert run_count["value"] == 2, f"Expected 2 attempts, got {run_count['value']}" + # Should have reset errors once + assert call_count["reset"] == 1, f"Expected 1 reset call, got {call_count['reset']}" + # Should have checked errors twice + assert call_count["check"] == 2, f"Expected 2 check calls, got {call_count['check']}" + # Final status should be finished (not error) + final_status = updates[-1][0] + assert final_status == EvaluationRunStatus.finished, f"Expected finished, got {final_status}" diff --git a/validator/main.py b/validator/main.py index d5f84ab6a..700af772b 100644 --- a/validator/main.py +++ b/validator/main.py @@ -153,6 +153,36 @@ async def _run_evaluation_run_with_semaphore(evaluation_run_id: UUID, problem_na async with semaphore: 
# Maximum number of times to retry a single evaluation run when platform
# inference errors exceed the threshold. After this many retries the run
# is marked as a platform error so the evaluation can still finish.
MAX_SINGLE_RUN_RETRIES = 2

async def _check_and_handle_inference_errors(evaluation_run_id: UUID, problem_name: str) -> bool:
    """Ask the inference gateway whether this run hit too many platform errors.

    Queries the gateway's ``/api/usage`` endpoint and compares
    ``inference_errors`` with ``max_inference_errors``.

    Returns:
        True if the error threshold was reached (caller should retry).
        False otherwise, including when the gateway cannot be reached or
        answers with a non-200 status: the check is best-effort and never
        raises.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            usage_response = await client.get(f"{config.RIDGES_INFERENCE_GATEWAY_URL}/api/usage?evaluation_run_id={evaluation_run_id}")
            if usage_response.status_code != 200:
                # Previously ignored silently, which hid gateway problems.
                logger.warning(f"Failed to check inference error count for evaluation run {evaluation_run_id}: gateway returned status {usage_response.status_code}")
                return False

            usage = usage_response.json()
            inference_errors = usage.get("inference_errors", 0)
            # A gateway that does not report a limit can never trip the threshold.
            max_inference_errors = usage.get("max_inference_errors", float("inf"))
            if inference_errors >= max_inference_errors:
                logger.warning(f"Run {evaluation_run_id} for {problem_name} hit inference error threshold ({inference_errors}/{max_inference_errors})")
                return True
    except Exception as e:
        logger.warning(f"Failed to check inference error count for evaluation run {evaluation_run_id}: {e}")
    return False

async def _reset_inference_errors(evaluation_run_id: UUID) -> bool:
    """Reset the inference error counter on the gateway before retrying a run.

    Returns:
        True if the gateway answered with a 2xx status, False if the request
        failed or was rejected. Best-effort: never raises. Callers that
        ignore the return value behave as before.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            reset_response = await client.post(f"{config.RIDGES_INFERENCE_GATEWAY_URL}/api/reset-inference-errors?evaluation_run_id={evaluation_run_id}")
            if not 200 <= reset_response.status_code < 300:
                # If the counter was not reset, the gateway keeps blocking the
                # run, so make the failure visible.
                logger.warning(f"Failed to reset inference error count for evaluation run {evaluation_run_id}: gateway returned status {reset_response.status_code}")
                return False
            return True
    except Exception as e:
        logger.warning(f"Failed to reset inference error count for evaluation run {evaluation_run_id}: {e}")
        return False
logger.info(f"Starting evaluation run {evaluation_run_id} for problem {problem_name}...") - - try: - # Move from pending -> initializing_agent - await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.initializing_agent) - - # Start initializing the agent sandbox - agent_sandbox = await asyncio.to_thread( - problem_suite.initialize_agent_sandbox, - sandbox_manager, - problem, - evaluation_run_id, - agent_code, - running_agent_timeout_seconds, - include_solutions=config.INCLUDE_SOLUTIONS - ) - - # Move from initializing_agent -> running_agent - await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.running_agent) - - # Start running the agent sandbox - patch, agent_logs = await asyncio.to_thread( - problem_suite.run_agent_sandbox, - sandbox_manager, - agent_sandbox - ) - logger.info(f"Finished running agent for problem {problem_name}: {len(patch.splitlines())} lines of patch, {len(agent_logs.splitlines())} lines of agent logs") - - # Move from running_agent -> initializing_eval - await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.initializing_eval, { - "patch": patch, - "agent_logs": truncate_logs_if_required(agent_logs) - }) - - # Start initializing the evaluation sandbox - eval_sandbox = await asyncio.to_thread( - problem_suite.initialize_eval_sandbox, - sandbox_manager, - problem, - evaluation_run_id, - patch, - running_eval_timeout_seconds - ) - - # Move from initializing_eval -> running_eval - await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.running_eval) - - # Start running the evaluation sandbox - test_results, eval_logs = await asyncio.to_thread( - problem_suite.run_eval_sandbox, - sandbox_manager, - eval_sandbox - ) - num_passed = sum(1 for test in test_results if test.status == ProblemTestResultStatus.PASS) - num_failed = sum(1 for test in test_results if test.status == ProblemTestResultStatus.FAIL) - num_skipped = sum(1 for test in test_results 
if test.status == ProblemTestResultStatus.SKIP) - logger.info(f"Finished running evaluation for problem {problem_name}: {len(test_results)} test results ({num_passed} passed, {num_failed} failed, {num_skipped} skipped), {len(eval_logs.splitlines())} lines of eval logs") - - # Move from running_eval -> finished - await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.finished, { - "test_results": [test.model_dump() for test in test_results], - "eval_logs": truncate_logs_if_required(eval_logs) - }) - - except EvaluationRunException as e: - logger.error(f"Evaluation run {evaluation_run_id} for problem {problem_name} errored: {e}") - - await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.error, { - "error_code": e.error_code.value, - "error_message": e.error_message - }) - - except Exception as e: - logger.error(f"Evaluation run {evaluation_run_id} for problem {problem_name} errored: {EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR.get_error_message()}: {e}") - logger.error(traceback.format_exc()) - - await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.error, { - "error_code": EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR.value, - "error_message": f"{EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR.get_error_message()}: {e}\n\nTraceback:\n{traceback.format_exc()}" - }) + # Retry loop: if the agent is hit by too many platform inference errors, + # reset the error counter and retry just this run (not the whole evaluation). 
+ for attempt in range(1 + MAX_SINGLE_RUN_RETRIES): + + try: + # Move from pending -> initializing_agent + await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.initializing_agent) + + # Start initializing the agent sandbox + agent_sandbox = await asyncio.to_thread( + problem_suite.initialize_agent_sandbox, + sandbox_manager, + problem, + evaluation_run_id, + agent_code, + running_agent_timeout_seconds, + include_solutions=config.INCLUDE_SOLUTIONS + ) + + # Move from initializing_agent -> running_agent + await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.running_agent) + + # Start running the agent sandbox + patch, agent_logs = await asyncio.to_thread( + problem_suite.run_agent_sandbox, + sandbox_manager, + agent_sandbox + ) + logger.info(f"Finished running agent for problem {problem_name}: {len(patch.splitlines())} lines of patch, {len(agent_logs.splitlines())} lines of agent logs") + + # Check if the agent was affected by platform-side inference errors. + # Instead of failing the whole evaluation, retry just this run. 
+ exceeded = await _check_and_handle_inference_errors(evaluation_run_id, problem_name) + if exceeded and attempt < MAX_SINGLE_RUN_RETRIES: + logger.info(f"Retrying run {evaluation_run_id} for {problem_name} (attempt {attempt + 2}/{1 + MAX_SINGLE_RUN_RETRIES}) due to platform inference errors") + await _reset_inference_errors(evaluation_run_id) + continue + elif exceeded: + # Exhausted retries — mark as platform error + raise EvaluationRunException( + EvaluationRunErrorCode.PLATFORM_TOO_MANY_INFERENCE_ERRORS, + f"{EvaluationRunErrorCode.PLATFORM_TOO_MANY_INFERENCE_ERRORS.get_error_message()}: exceeded inference error threshold after {1 + MAX_SINGLE_RUN_RETRIES} attempts", + extra={"agent_logs": truncate_logs_if_required(agent_logs)} + ) + + # Move from running_agent -> initializing_eval + await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.initializing_eval, { + "patch": patch, + "agent_logs": truncate_logs_if_required(agent_logs) + }) + + # Start initializing the evaluation sandbox + eval_sandbox = await asyncio.to_thread( + problem_suite.initialize_eval_sandbox, + sandbox_manager, + problem, + evaluation_run_id, + patch, + running_eval_timeout_seconds + ) + + # Move from initializing_eval -> running_eval + await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.running_eval) + + # Start running the evaluation sandbox + test_results, eval_logs = await asyncio.to_thread( + problem_suite.run_eval_sandbox, + sandbox_manager, + eval_sandbox + ) + num_passed = sum(1 for test in test_results if test.status == ProblemTestResultStatus.PASS) + num_failed = sum(1 for test in test_results if test.status == ProblemTestResultStatus.FAIL) + num_skipped = sum(1 for test in test_results if test.status == ProblemTestResultStatus.SKIP) + logger.info(f"Finished running evaluation for problem {problem_name}: {len(test_results)} test results ({num_passed} passed, {num_failed} failed, {num_skipped} skipped), 
{len(eval_logs.splitlines())} lines of eval logs") + + # Move from running_eval -> finished + await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.finished, { + "test_results": [test.model_dump() for test in test_results], + "eval_logs": truncate_logs_if_required(eval_logs) + }) + + # Success — break out of the retry loop + break + + except EvaluationRunException as e: + logger.error(f"Evaluation run {evaluation_run_id} for problem {problem_name} errored: {e}") + + await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.error, { + "error_code": e.error_code.value, + "error_message": e.error_message, + **(e.extra or {}) + }) + break + + except Exception as e: + logger.error(f"Evaluation run {evaluation_run_id} for problem {problem_name} errored: {EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR.get_error_message()}: {e}") + logger.error(traceback.format_exc()) + + await update_evaluation_run(evaluation_run_id, problem_name, EvaluationRunStatus.error, { + "error_code": EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR.value, + "error_message": f"{EvaluationRunErrorCode.VALIDATOR_INTERNAL_ERROR.get_error_message()}: {e}\n\nTraceback:\n{traceback.format_exc()}" + }) + break