Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions inference_gateway/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@
logger.fatal("MAX_COST_PER_EVALUATION_RUN_USD is not set in .env")
MAX_COST_PER_EVALUATION_RUN_USD = float(MAX_COST_PER_EVALUATION_RUN_USD)

# Cap on platform-side inference errors per evaluation run; once a run reaches
# this count, further inference/embedding requests for it are rejected with 503.
MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = os.getenv("MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN")
if not MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN:
    # Unlike MAX_COST_PER_EVALUATION_RUN_USD above, a missing value is non-fatal:
    # fall back to a default of 5 and warn.
    MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = 5
    logger.warning(f"MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN is not set in .env, defaulting to {MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN}")
# Env vars arrive as strings; normalize to int (raises ValueError if non-numeric).
MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN = int(MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN)



USE_CHUTES = os.getenv("USE_CHUTES")
Expand Down
55 changes: 55 additions & 0 deletions inference_gateway/error_hash_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Tracks the number of platform-side inference errors per evaluation run.
# When the count exceeds a configured threshold the run is flagged as a
# platform error so the agent is not penalized unfairly.

import time

from uuid import UUID
from pydantic import BaseModel



ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS = 60 # 1 minute

class ErrorHashMapEntry(BaseModel):
    # Number of platform-side inference errors recorded for one evaluation run.
    inference_errors: int
    # time.time() timestamp of the last read/write of this entry; used by
    # ErrorHashMap._cleanup for idle-based eviction.
    last_accessed_at: float

class ErrorHashMap:
    """In-memory, per-evaluation-run counter of platform-side inference errors.

    Entries are evicted lazily: whenever the map is touched more than
    ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS after the previous sweep, every
    entry that has been idle for at least that long is dropped.
    """

    def __init__(self):
        # Maps evaluation-run UUID -> ErrorHashMapEntry.
        self.error_hash_map = {}
        self.last_cleanup_at = time.time()

    def _cleanup(self):
        """Sweep out idle entries; runs at most once per cleanup interval."""
        now = time.time()
        if now - self.last_cleanup_at <= ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS:
            return
        # NOTE(review): the idle TTL equals the sweep interval (60s), so a run
        # that goes quiet for ~a minute loses its error count — confirm that
        # forgetting counts this quickly is intended for long evaluation runs.
        self.error_hash_map = {
            run_id: entry
            for run_id, entry in self.error_hash_map.items()
            if now - entry.last_accessed_at < ERROR_HASH_MAP_CLEANUP_INTERVAL_SECONDS
        }
        self.last_cleanup_at = now

    def get_inference_errors(self, uuid: UUID) -> int:
        """Return the current error count for *uuid*, or 0 if it has no entry."""
        self._cleanup()
        entry = self.error_hash_map.get(uuid)
        if entry is None:
            return 0
        # Reading the count keeps the entry alive.
        entry.last_accessed_at = time.time()
        return entry.inference_errors

    def add_inference_error(self, uuid: UUID):
        """Record one platform-side inference error for *uuid*."""
        self._cleanup()
        entry = self.error_hash_map.get(uuid)
        if entry is None:
            self.error_hash_map[uuid] = ErrorHashMapEntry(
                inference_errors=1, last_accessed_at=time.time()
            )
        else:
            entry.inference_errors += 1
            entry.last_accessed_at = time.time()

    def reset_inference_errors(self, uuid: UUID):
        """Reset the error count for a specific evaluation run (used before retrying)."""
        # No-op when the run has no entry.
        self.error_hash_map.pop(uuid, None)
60 changes: 57 additions & 3 deletions inference_gateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from contextlib import asynccontextmanager
from models.evaluation_run import EvaluationRunStatus
from inference_gateway.cost_hash_map import CostHashMap
from inference_gateway.error_hash_map import ErrorHashMap
from inference_gateway.providers.provider import Provider
from inference_gateway.providers.chutes import ChutesProvider
from inference_gateway.providers.targon import TargonProvider
Expand Down Expand Up @@ -107,6 +108,18 @@ async def wrapper(*args, **kwargs):


cost_hash_map = CostHashMap()
error_hash_map = ErrorHashMap()



# Status codes treated as platform failures rather than agent mistakes.
# 4xx codes such as 400/404/422 are deliberately absent — they indicate a bad
# agent request (wrong model, invalid format, etc) — and 429 is absent because
# it means the cost limit was intentionally reached. -1 is the sentinel used
# for provider failures with no HTTP status.
PLATFORM_ERROR_CODES = {500, 502, 503, 504, -1}

def is_platform_error(status_code: int) -> bool:
    """Return True if *status_code* counts against the run's platform-error budget."""
    return status_code in PLATFORM_ERROR_CODES



Expand Down Expand Up @@ -158,7 +171,16 @@ async def inference(request: InferenceRequest) -> InferenceResponse:
detail=f"The evaluation run with ID {request.evaluation_run_id} has reached or exceeded the evaluation run cost limit of {config.MAX_COST_PER_EVALUATION_RUN_USD} USD (current cost: {cost} USD)."
)


# Make sure the evaluation run has not had too many platform-side inference errors
inference_errors = error_hash_map.get_inference_errors(request.evaluation_run_id)
if inference_errors >= config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN:
logger.warning(f"Blocking inference for run {request.evaluation_run_id}: too many platform errors ({inference_errors}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})")
raise HTTPException(
status_code=503,
detail=f"The evaluation run with ID {request.evaluation_run_id} has had too many platform-side inference errors ({inference_errors} errors, limit is {config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})."
)



# Make sure we support the model for inference
provider = get_provider_that_supports_model_for_inference(request.model)
Expand Down Expand Up @@ -206,6 +228,12 @@ async def inference(request: InferenceRequest) -> InferenceResponse:
tool_calls=response.tool_calls
)
else:
# Track platform errors (provider failures, not agent mistakes)
if is_platform_error(response.status_code):
error_hash_map.add_inference_error(request.evaluation_run_id)
error_count = error_hash_map.get_inference_errors(request.evaluation_run_id)
logger.warning(f"Platform inference error for run {request.evaluation_run_id}: status {response.status_code} (error {error_count}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})")

raise HTTPException(
status_code=response.status_code,
detail=response.error_message
Expand Down Expand Up @@ -245,7 +273,16 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse:
detail=f"The evaluation run with ID {request.evaluation_run_id} has reached or exceeded the evaluation run cost limit of {config.MAX_COST_PER_EVALUATION_RUN_USD} USD (current cost: {cost} USD)."
)


# Make sure the evaluation run has not had too many platform-side inference errors
inference_errors = error_hash_map.get_inference_errors(request.evaluation_run_id)
if inference_errors >= config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN:
logger.warning(f"Blocking embedding for run {request.evaluation_run_id}: too many platform errors ({inference_errors}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})")
raise HTTPException(
status_code=503,
detail=f"The evaluation run with ID {request.evaluation_run_id} has had too many platform-side inference errors ({inference_errors} errors, limit is {config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})."
)



# Make sure we support the model for embedding
provider = get_provider_that_supports_model_for_embedding(request.model)
Expand Down Expand Up @@ -287,6 +324,12 @@ async def embedding(request: EmbeddingRequest) -> EmbeddingResponse:
embedding=response.embedding
)
else:
# Track platform errors (provider failures, not agent mistakes)
if is_platform_error(response.status_code):
error_hash_map.add_inference_error(request.evaluation_run_id)
error_count = error_hash_map.get_inference_errors(request.evaluation_run_id)
logger.warning(f"Platform embedding error for run {request.evaluation_run_id}: status {response.status_code} (error {error_count}/{config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN})")

raise HTTPException(
status_code=response.status_code,
detail=response.error_message
Expand All @@ -298,6 +341,8 @@ class UsageResponse(BaseModel):
used_cost_usd: float
remaining_cost_usd: float
max_cost_usd: float
inference_errors: int
max_inference_errors: int

@app.get("/api/usage")
@handle_http_exceptions
Expand All @@ -306,9 +351,18 @@ async def usage(evaluation_run_id: UUID) -> UsageResponse:
return UsageResponse(
used_cost_usd=used_cost_usd,
remaining_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD - used_cost_usd,
max_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD
max_cost_usd=config.MAX_COST_PER_EVALUATION_RUN_USD,
inference_errors=error_hash_map.get_inference_errors(evaluation_run_id),
max_inference_errors=config.MAX_INFERENCE_ERRORS_PER_EVALUATION_RUN
)

@app.post("/api/reset-inference-errors")
@handle_http_exceptions
async def reset_inference_errors(evaluation_run_id: UUID):
    """Reset the inference error count for a specific evaluation run before retrying.

    Idempotent: resetting a run with no recorded errors is a no-op.
    """
    # NOTE(review): this endpoint appears to carry no caller authentication in
    # the visible code — confirm only trusted platform components can reach it,
    # otherwise an agent could clear its own error budget.
    error_hash_map.reset_inference_errors(evaluation_run_id)
    return {"status": "ok"}



@app.get("/api/inference-models")
Expand Down
1 change: 1 addition & 0 deletions models/evaluation_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __new__(cls, code: int, message: str):
PLATFORM_RESTARTED_WHILE_RUNNING_AGENT = (3020, "The platform was restarted while the evaluation run was running the agent")
PLATFORM_RESTARTED_WHILE_INIT_EVAL = (3030, "The platform was restarted while the evaluation run was initializing the evaluation")
PLATFORM_RESTARTED_WHILE_RUNNING_EVAL = (3040, "The platform was restarted while the evaluation run was running the evaluation")
PLATFORM_TOO_MANY_INFERENCE_ERRORS = (3050, "Too many platform-side inference errors occurred during the evaluation run")

def get_error_message(self) -> str:
return self.message
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading